PengLiu committed
Commit 56ef371 · 1 Parent(s): ce3eb8e

push inference code

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. README.md +3 -4
  2. demo/gradio_demo.py +253 -0
  3. detect_tools/upn/__init__.py +45 -0
  4. detect_tools/upn/builder.py +39 -0
  5. detect_tools/upn/configs/upn_large.py +73 -0
  6. detect_tools/upn/inference_wrapper.py +237 -0
  7. detect_tools/upn/models/architecture/__init__.py +4 -0
  8. detect_tools/upn/models/architecture/deformable_transformer.py +336 -0
  9. detect_tools/upn/models/architecture/upn_model.py +343 -0
  10. detect_tools/upn/models/backbone/__init__.py +4 -0
  11. detect_tools/upn/models/backbone/swin.py +814 -0
  12. detect_tools/upn/models/backbone/wrapper.py +297 -0
  13. detect_tools/upn/models/decoder/__init__.py +3 -0
  14. detect_tools/upn/models/decoder/upn_decoder.py +378 -0
  15. detect_tools/upn/models/encoder/__init__.py +3 -0
  16. detect_tools/upn/models/encoder/upn_encoder.py +288 -0
  17. detect_tools/upn/models/module/__init__.py +5 -0
  18. detect_tools/upn/models/module/contrastive.py +29 -0
  19. detect_tools/upn/models/module/mlp.py +18 -0
  20. detect_tools/upn/models/module/nested_tensor.py +199 -0
  21. detect_tools/upn/models/utils/__init__.py +23 -0
  22. detect_tools/upn/models/utils/detr_utils.py +415 -0
  23. detect_tools/upn/ops/functions/__init__.py +10 -0
  24. detect_tools/upn/ops/functions/ms_deform_attn_func.py +61 -0
  25. detect_tools/upn/ops/modules/__init__.py +9 -0
  26. detect_tools/upn/ops/modules/ms_deform_attn.py +204 -0
  27. detect_tools/upn/ops/modules/ms_deform_attn_key_aware.py +130 -0
  28. detect_tools/upn/ops/setup.py +73 -0
  29. detect_tools/upn/ops/src/cpu/ms_deform_attn_cpu.cpp +41 -0
  30. detect_tools/upn/ops/src/cpu/ms_deform_attn_cpu.h +33 -0
  31. detect_tools/upn/ops/src/cuda/ms_deform_attn_cuda.cu +153 -0
  32. detect_tools/upn/ops/src/cuda/ms_deform_attn_cuda.h +30 -0
  33. detect_tools/upn/ops/src/cuda/ms_deform_im2col_cuda.cuh +1327 -0
  34. detect_tools/upn/ops/src/ms_deform_attn.h +62 -0
  35. detect_tools/upn/ops/src/vision.cpp +16 -0
  36. detect_tools/upn/ops/test.py +89 -0
  37. detect_tools/upn/transforms/transform.py +142 -0
  38. requirements.txt +10 -0
  39. run.sh +23 -0
  40. vlm_fo1/__init__.py +1 -0
  41. vlm_fo1/constants.py +29 -0
  42. vlm_fo1/mm_utils.py +658 -0
  43. vlm_fo1/model/__init__.py +1 -0
  44. vlm_fo1/model/builder.py +143 -0
  45. vlm_fo1/model/language_model/omchat_qwen2_5_vl.py +576 -0
  46. vlm_fo1/model/multimodal_encoder/__init__.py +0 -0
  47. vlm_fo1/model/multimodal_encoder/base_encoder.py +33 -0
  48. vlm_fo1/model/multimodal_encoder/builder.py +38 -0
  49. vlm_fo1/model/multimodal_encoder/davit/configs.py +152 -0
  50. vlm_fo1/model/multimodal_encoder/davit/configuration_davit.py +119 -0
README.md CHANGED
@@ -1,14 +1,13 @@
 ---
 title: VLM FO1 3B Demo
-emoji: 📊
+emoji: 🐠
 colorFrom: green
-colorTo: indigo
+colorTo: yellow
 sdk: gradio
 sdk_version: 5.49.1
-app_file: app.py
+app_file: run.sh
 pinned: false
 license: apache-2.0
 short_description: VLM-FO1-3B-Demo
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
demo/gradio_demo.py ADDED
@@ -0,0 +1,253 @@
+ import gradio as gr
+ from PIL import Image, ImageDraw, ImageFont
+ import re
+ import numpy as np
+ from skimage.measure import label, regionprops
+ from skimage.morphology import binary_dilation, disk
+ from detect_tools.upn import UPNWrapper
+ from vlm_fo1.model.builder import load_pretrained_model
+ from vlm_fo1.mm_utils import (
+     prepare_inputs,
+     extract_predictions_to_indexes,
+ )
+ from vlm_fo1.task_templates import *
+ import torch
+
+
+ TASK_TYPES = {
+     "OD/REC": OD_template,
+     "ODCounting": OD_Counting_template,
+     "Region_OCR": "Please provide the ocr results of these regions in the image.",
+     "Brief_Region_Caption": "Provide a brief description for these regions in the image.",
+     "Detailed_Region_Caption": "Provide a detailed description for these regions in the image.",
+     "Grounding": Grounding_template,
+     "Viusal_Region_Reasoning": Viusal_Region_Reasoning_template,
+ }
+
+
+ def detect_model(image, threshold=0.3):
+     proposals = upn_model.inference(image)
+     filtered_proposals = upn_model.filter(proposals, min_score=threshold)
+     return filtered_proposals['original_xyxy_boxes'][0][:100]
+
+
+ def multimodal_model(image, bboxes, text):
+     if '<image>' in text:
+         print(text)
+         parts = [part.replace('\\n', '\n') for part in re.split(rf'(<image>)', text) if part.strip()]
+         print(parts)
+         content = []
+         for part in parts:
+             if part == '<image>':
+                 content.append({"type": "image_url", "image_url": {"url": image}})
+             else:
+                 content.append({"type": "text", "text": part})
+     else:
+         content = [{
+             "type": "image_url",
+             "image_url": {
+                 "url": image
+             }
+         }, {
+             "type": "text",
+             "text": text
+         }]
+
+     messages = [
+         {
+             "role": "user",
+             "content": content,
+             "bbox_list": bboxes
+         }
+     ]
+     generation_kwargs = prepare_inputs(model_path, model, image_processors, tokenizer, messages,
+                                        max_tokens=4096, top_p=0.05, temperature=0.0, do_sample=False)
+     with torch.inference_mode():
+         output_ids = model.generate(**generation_kwargs)
+     outputs = tokenizer.decode(output_ids[0, generation_kwargs['inputs'].shape[1]:]).strip()
+     print("========output========\n", outputs)
+
+     prediction_dict = extract_predictions_to_indexes(outputs)
+
+     ans_bbox_json = []
+     ans_bbox_list = []
+     for k, v in prediction_dict.items():
+         for box_index in v:
+             box_index = int(box_index)
+             if box_index < len(bboxes):
+                 current_bbox = bboxes[box_index]
+                 ans_bbox_json.append({
+                     "region_index": f"<region{box_index}>",
+                     "xmin": current_bbox[0],
+                     "ymin": current_bbox[1],
+                     "xmax": current_bbox[2],
+                     "ymax": current_bbox[3],
+                     "label": k
+                 })
+                 ans_bbox_list.append(current_bbox)
+
+     return outputs, ans_bbox_json, ans_bbox_list
+
+
+ def draw_bboxes(image, bboxes, labels=None):
+     image = image.copy()
+     draw = ImageDraw.Draw(image)
+
+     for bbox in bboxes:
+         draw.rectangle(bbox, outline="red", width=3)
+     return image
+
+
+ def extract_bbox_and_original_image(edited_image: dict):
+     original_image = edited_image["background"]
+     bbox_list = []
+
+     if original_image is None:
+         return None, "Error, Please upload an image."
+
+     if edited_image["layers"] is None or len(edited_image["layers"]) == 0:
+         return original_image, []
+
+     drawing_layer = edited_image["layers"][0]
+     alpha_channel = drawing_layer.getchannel('A')
+     alpha_np = np.array(alpha_channel)
+
+     binary_mask = alpha_np > 0
+
+     structuring_element = disk(5)
+     dilated_mask = binary_dilation(binary_mask, structuring_element)
+
+     labeled_image = label(dilated_mask)
+     regions = regionprops(labeled_image)
+
+     for prop in regions:
+         y_min, x_min, y_max, x_max = prop.bbox
+         bbox_list.append((x_min, y_min, x_max, y_max))
+
+     return original_image, bbox_list
+
+
+ def process(image, prompt, threshold):
+     image, bbox_list = extract_bbox_and_original_image(image)
+     image = image.convert('RGB')
+
+     if len(bbox_list) == 0:
+         # Get bboxes from detection model
+         bboxes = detect_model(image, threshold)
+     else:
+         bboxes = bbox_list
+     for idx in range(len(bboxes)):
+         prompt += f'<region{idx}>'
+
+     ans, ans_bbox_json, ans_bbox_list = multimodal_model(image, bboxes, prompt)
+
+     image_with_opn = draw_bboxes(image, bboxes)
+
+     annotated_bboxes = []
+     if len(ans_bbox_json) > 0:
+         for item in ans_bbox_json:
+             annotated_bboxes.append(
+                 ((int(item['xmin']), int(item['ymin']), int(item['xmax']), int(item['ymax'])), item['label'])
+             )
+     annotated_image = (image, annotated_bboxes)
+
+     return annotated_image, image_with_opn, ans, ans_bbox_json
+
+
+ def show_label_input(choice):
+     return gr.update(visible=(choice == "OmDet"))
+
+
+ def update_btn(is_processing):
+     if is_processing:
+         return gr.update(value="Processing...", interactive=False)
+     else:
+         return gr.update(value="Submit", interactive=True)
+
+
+ def launch_demo():
+     with gr.Blocks() as demo:
+         gr.Markdown("## VLM-FO1 Demo")
+         gr.Markdown("""
+         **Instructions:**
+         1. Upload an image, then you can either draw circular regions on it using the red brush as the input regions or let the detection model detect the regions for you.
+         2. Select a task template and replace the [WRITE YOUR INPUT HERE] with your input targets, or write your own prompt.\n
+         For example, if you want to detect "person" and "dog", you can replace the [WRITE YOUR INPUT HERE] with "person, dog".\n
+         3. Adjust the detection threshold if needed
+         4. Click Submit to get results
+         """)
+
+         with gr.Row():
+             with gr.Column():
+                 img_input_draw = gr.ImageEditor(
+                     label="Image Input",
+                     image_mode="RGBA",
+                     type="pil",
+                     sources=['upload'],
+                     brush=gr.Brush(colors=["#FF0000"], color_mode="fixed", default_size=2),
+                     interactive=True
+                 )
+
+                 gr.Markdown("### Prompt & Parameters")
+
+                 def set_prompt_from_template(selected_task):
+                     return gr.update(value=TASK_TYPES[selected_task].format("[WRITE YOUR INPUT HERE]"))
+
+                 task_type_input = gr.Dropdown(
+                     choices=list(TASK_TYPES.keys()),
+                     value="OD/REC",
+                     label="Prompt Templates",
+                     info="Select the prompt template for the task, or write your own prompt."
+                 )
+
+                 prompt_input = gr.Textbox(
+                     label="Task Prompt",
+                     value=TASK_TYPES["OD/REC"].format("[WRITE YOUR INPUT HERE]"),
+                     lines=2,
+                 )
+
+                 task_type_input.change(
+                     set_prompt_from_template,
+                     inputs=task_type_input,
+                     outputs=prompt_input
+                 )
+
+                 threshold_input = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Detection Model Threshold")
+                 submit_btn = gr.Button("Submit", variant="primary")
+
+             with gr.Column():
+                 with gr.Accordion("Detection Result", open=True):
+                     image_output_opn = gr.Image(label="Detection Result")
+
+                 image_output = gr.AnnotatedImage(label="Multimodal Model Output", height=500)
+
+                 result_output = gr.Textbox(label="Multimodal Model Output")
+                 ans_bbox_json = gr.JSON(label="Extracted Detection Output")
+
+         submit_btn.click(update_btn, inputs=[gr.State(True)], outputs=[submit_btn], queue=False).then(
+             process,
+             inputs=[img_input_draw, prompt_input, threshold_input],
+             outputs=[image_output, image_output_opn, result_output, ans_bbox_json],
+             queue=True
+         ).then(update_btn, inputs=[gr.State(False)], outputs=[submit_btn], queue=False)
+
+     return demo
+
+
+ if __name__ == "__main__":
+     model_path = './resources/VLM-FO1_Qwen2.5-VL-3B-v01'
+     upn_ckpt_path = "./resources/upn_large.pth"
+     tokenizer, model, image_processors = load_pretrained_model(
+         model_path=model_path,
+         device="cuda:0",
+     )
+     upn_model = UPNWrapper(upn_ckpt_path)
+
+     demo = launch_demo()
+     demo.launch()
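For reference, the same pipeline can be driven without the Gradio UI. The sketch below mirrors the process() flow above: propose regions with UPN, append one <region{idx}> token per box to the prompt, and let the VLM resolve the indexes. It is a minimal sketch that only reuses names introduced in this commit; the resource paths, the example image, the literal prompt string (the demo builds it from a task template), and the availability of a CUDA device are assumptions.

# Minimal non-UI sketch of the demo pipeline (paths and the input image are assumptions).
import torch
from PIL import Image
from detect_tools.upn import UPNWrapper
from vlm_fo1.model.builder import load_pretrained_model
from vlm_fo1.mm_utils import prepare_inputs, extract_predictions_to_indexes

model_path = "./resources/VLM-FO1_Qwen2.5-VL-3B-v01"  # assumed local checkpoint
tokenizer, model, image_processors = load_pretrained_model(model_path=model_path, device="cuda:0")
upn_model = UPNWrapper("./resources/upn_large.pth")   # assumed local checkpoint

image = Image.open("example.jpg").convert("RGB")      # hypothetical input image
proposals = upn_model.filter(upn_model.inference(image), min_score=0.3)
bboxes = proposals["original_xyxy_boxes"][0][:100]

# One <region{idx}> token per proposed box, the same convention process() uses.
prompt = "Please detect person and dog in the image." + "".join(f"<region{i}>" for i in range(len(bboxes)))
messages = [{"role": "user",
             "content": [{"type": "image_url", "image_url": {"url": image}},
                         {"type": "text", "text": prompt}],
             "bbox_list": bboxes}]
kwargs = prepare_inputs(model_path, model, image_processors, tokenizer, messages,
                        max_tokens=4096, top_p=0.05, temperature=0.0, do_sample=False)
with torch.inference_mode():
    output_ids = model.generate(**kwargs)
answer = tokenizer.decode(output_ids[0, kwargs["inputs"].shape[1]:]).strip()
print(extract_predictions_to_indexes(answer))  # label -> list of region indexes into bboxes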
detect_tools/upn/__init__.py ADDED
@@ -0,0 +1,45 @@
+ from . import models
+ from .builder import (
+     ARCHITECTURES,
+     BACKBONES,
+     DECODERS,
+     ENCODERS,
+     POS_EMBEDDINGS,
+     build_architecture,
+     build_backbone,
+     build_decoder,
+     build_encoder,
+     build_position_embedding,
+ )
+ from .inference_wrapper import UPNWrapper
+ from .models.architecture import *
+ from .models.backbone import *
+ from .models.decoder import *
+ from .models.encoder import *
+ from .models.module import *
+ from .models.utils import *
+
+ __all__ = [
+     "BACKBONES",
+     "POS_EMBEDDINGS",
+     "ENCODERS",
+     "DECODERS",
+     "ARCHITECTURES",
+     "build_backbone",
+     "build_position_embedding",
+     "build_encoder",
+     "build_decoder",
+     "build_architecture",
+     "UPNWrapper",
+ ]
+
+ __all__ += (
+     models.module.__all__
+     + models.utils.__all__
+     + models.architecture.__all__
+     + models.backbone.__all__
+     + models.encoder.__all__
+     + models.decoder.__all__
+     + models.module.__all__
+     + models.utils.__all__
+ )
detect_tools/upn/builder.py ADDED
@@ -0,0 +1,39 @@
+ from mmengine import Registry, build_from_cfg
+
+ BACKBONES = Registry("backbone")
+ POS_EMBEDDINGS = Registry("position_embedding")
+ FUSERS = Registry("fuser")
+ ENCODERS = Registry("encoder")
+ DECODERS = Registry("decoder")
+ ARCHITECTURES = Registry("architecture")
+
+
+ def build_backbone(cfg):
+     """Build backbone."""
+     return build_from_cfg(cfg, BACKBONES)
+
+
+ def build_position_embedding(cfg):
+     """Build position embedding."""
+     return build_from_cfg(cfg, POS_EMBEDDINGS)
+
+
+ def build_fuser(cfg):
+     """Build fuser."""
+     return build_from_cfg(cfg, FUSERS)
+
+
+ def build_encoder(cfg):
+     """Build encoder."""
+     return build_from_cfg(cfg, ENCODERS)
+
+
+ def build_decoder(cfg):
+     """Build decoder."""
+     return build_from_cfg(cfg, DECODERS)
+
+
+ def build_architecture(cfg):
+     """Build architecture."""
+     return build_from_cfg(cfg, ARCHITECTURES)
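These registries follow mmengine's config-driven construction: a class decorated with register_module() can later be instantiated from a plain dict whose "type" field names the class, with the remaining keys passed as constructor arguments. A minimal sketch of the pattern; the ToyBackbone class is hypothetical and not part of this commit.

import torch.nn as nn
from detect_tools.upn.builder import BACKBONES, build_backbone


@BACKBONES.register_module()
class ToyBackbone(nn.Module):  # hypothetical example, for illustration only
    def __init__(self, out_channels: int = 256):
        super().__init__()
        self.conv = nn.Conv2d(3, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


# "type" selects the registered class; the remaining keys become constructor kwargs.
backbone = build_backbone(dict(type="ToyBackbone", out_channels=128))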
detect_tools/upn/configs/upn_large.py ADDED
@@ -0,0 +1,73 @@
+ transformer_cfg = dict(
+     type="DeformableTransformer",
+     num_queries=900,
+     encoder_cfg=dict(
+         type="UPNEncoder",
+         encoder_layer_cfg=dict(
+             type="DeformableTransformerEncoderLayer",
+             activation="relu",
+             d_model=256,
+             dropout=0.0,
+             d_ffn=2048,
+             n_heads=8,
+             n_levels=5,
+         ),
+         d_model=256,
+         num_layers=6,
+         use_checkpoint=False,
+         use_transformer_ckpt=False,
+     ),
+     decoder_cfg=dict(
+         type="UPNDecoder",
+         decoder_layer_cfg=dict(
+             type="DeformableTransformerDecoderLayer",
+             activation="relu",
+             d_model=256,
+             n_heads=8,
+             dropout=0.0,
+             d_ffn=2048,
+             n_levels=5,
+         ),
+         d_model=256,
+         return_intermediate=True,
+         num_layers=6,
+         rm_dec_query_scale=True,
+         use_detached_boxes_dec_out=False,
+     ),
+     learnable_tgt_init=True,
+     random_refpoints_xy=False,
+     num_feature_levels=5,
+     two_stage_bbox_embed_share=False,
+     two_stage_class_embed_share=False,
+     two_stage_keep_all_tokens=False,
+     two_stage_learn_wh=False,
+     two_stage_type="standard",
+     binary_query_selection=False,
+ )
+
+ vision_backbone = dict(
+     type="SwinWrapper",
+     backbone_cfg="swin_L_384_22k",
+     lr_backbone=1e-05,
+     dilation=False,
+     return_interm_indices=[0, 1, 2, 3],
+     backbone_freeze_keywords=None,
+     backbone_ckpt_path=None,
+     use_checkpoint=False,
+     position_embedding_cfg=dict(
+         type="PositionEmbeddingSineHW",
+         normalize=True,
+         num_pos_feats=128,
+         temperatureH=20,
+         temperatureW=20,
+     ),
+ )
+
+ model = dict(
+     type="UPN",
+     vision_backbone_cfg=vision_backbone,
+     transformer_cfg=transformer_cfg,
+     num_queries=900,
+     dec_pred_bbox_embed_share=True,
+     dec_pred_class_embed_share=True,
+ )
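This file is a plain mmengine config: inference_wrapper.build_model reads it with Config.fromfile and passes cfg.model to build_architecture. A short sketch of loading it standalone, assuming it is run from the repository root; the built model is randomly initialized until the upn_large.pth checkpoint is loaded.

from mmengine import Config
from detect_tools.upn import build_architecture

cfg = Config.fromfile("detect_tools/upn/configs/upn_large.py")
print(cfg.model.type, cfg.model.num_queries)  # "UPN", 900
model = build_architecture(cfg.model)         # weights still need to be loaded from upn_large.pth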
detect_tools/upn/inference_wrapper.py ADDED
@@ -0,0 +1,237 @@
+ import copy
+ import os
+ from typing import Dict, List, Union
+
+ import numpy as np
+ import torch
+ from mmengine import Config
+ from PIL import Image
+ from torchvision.ops import nms
+
+ import detect_tools.upn.transforms.transform as T
+ from detect_tools.upn import build_architecture
+ from detect_tools.upn.models.module import nested_tensor_from_tensor_list
+
+
+ def build_model(
+     ckpt_path: str,
+ ):
+     current_path = os.path.dirname(os.path.abspath(__file__))
+     config_file = "configs/upn_large.py"
+     config_path = os.path.join(current_path, config_file)
+     model_cfg = Config.fromfile(config_path).model
+     model = build_architecture(model_cfg)
+     checkpoint = torch.load(ckpt_path, map_location="cpu")
+     model.load_state_dict(checkpoint["model"], strict=False)
+     return model
+
+
+ class UPNWrapper:
+     """A wrapper class for the UPN model.
+
+     Args:
+         ckpt_path (str): The path to the model checkpoint.
+     """
+
+     def __init__(self, ckpt_path: str):
+         self.model = build_model(ckpt_path)
+         self.model.eval()
+         self.model.to("cuda")
+
+     def inference(
+         self,
+         image: List[Union[str, Image.Image]],
+         prompt_type: str = 'fine_grained_prompt',
+     ):
+         """Single image prediction.
+
+         Args:
+             image (List[Union[str, Image.Image]]): A list of image paths or
+                 PIL.Image.Image objects.
+             prompt_type (str): The type of prompt to use for the prediction. Choice in
+                 ['fine_grained_prompt', 'coarse_grained_prompt'].
+
+         Returns:
+             Dict: Return dict in format:
+                 {
+                     "original_xyxy_boxes": (np.ndarray): Original prediction boxes in shape (batch_size, 900, 4),
+                     "scores": (np.ndarray): Score in shape (batch_size, N)
+                 }
+         """
+         if not isinstance(image, list):
+             image = [image]
+         input_images, image_sizes = self.construct_input(image)
+         outputs = self._inference(input_images, prompt_type)
+         post_processed_outputs = self.postprocess(outputs, image_sizes)
+         return post_processed_outputs
+
+     def _inference(self, input_images: List[torch.Tensor], prompt_type: str):
+         """Run the UPN model on the transformed images.
+
+         Args:
+             input_images (List[torch.Tensor]): Transformed images.
+
+         Returns:
+             (Dict): Return dict with keys:
+                 - query_features: (torch.Tensor): Query features in shape (batch_size, N, 256)
+                 - pred_boxes: (torch.Tensor): Normalized prediction boxes in shape (batch_size, N, 4),
+                     in cxcywh format
+         """
+         input_images = nested_tensor_from_tensor_list(input_images)
+         input_images = input_images.to("cuda")
+         with torch.no_grad():
+             outputs = self.model(input_images, prompt_type)
+         return outputs
+
+     def construct_input(self, image: List[Union[str, Image.Image]]):
+         """Construct input for the model.
+
+         Args:
+             image (List[Union[str, Image.Image]]): A list of image paths or
+                 PIL.Image.Image objects. If the length of the list is more than 1, the model will
+                 perform batch inference.
+
+         Returns:
+             Tuple[List[torch.Tensor], List[List[int]]]: A tuple containing the
+                 input images and the sizes of the input images.
+         """
+         input_images = []
+         image_sizes = []
+         for _, img in enumerate(image):
+             if isinstance(img, str):
+                 img = Image.open(img)
+             elif isinstance(img, Image.Image):
+                 img = img
+             else:
+                 raise ValueError(
+                     "image must be either a string or a PIL.Image.Image object"
+                 )
+             W, H = img.size
+             image_sizes.append([H, W])
+             # add image in tensor format
+             input_images.append(self.transform_image(img))
+         return input_images, image_sizes
+
+     def transform_image(self, image_pil: Image.Image) -> torch.Tensor:
+         """Apply a set of transformations to a PIL image.
+
+         Args:
+             image_pil (PIL.Image.Image): The input image.
+
+         Returns:
+             torch.Tensor: The transformed image as a PyTorch tensor.
+         """
+         transform = T.Compose(
+             [
+                 T.RandomResize([800], max_size=1333),
+                 T.ToTensor(),
+                 T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+             ]
+         )
+         transformed_image, _ = transform(image_pil, None)  # 3, h, w
+         return transformed_image
+
+     def postprocess(
+         self,
+         outputs: Dict[str, torch.Tensor],
+         image_pil_sizes: List[List[int]] = None,
+     ):
+         boxes = outputs["pred_boxes"].cpu()
+         scores = (
+             outputs["pred_logits"].sigmoid().cpu() if "pred_logits" in outputs else None
+         )
+         normalized_xyxy_boxes = []
+         original_xyxy_boxes = []
+         for batch_idx, (H, W) in enumerate(image_pil_sizes):
+             batch_boxes = boxes[batch_idx]  # (num_queries, 4)
+             # from (cx, cy, w, h) to (x1, y1, x2, y2)
+             batch_boxes[:, 0] = batch_boxes[:, 0] - batch_boxes[:, 2] / 2
+             batch_boxes[:, 1] = batch_boxes[:, 1] - batch_boxes[:, 3] / 2
+             batch_boxes[:, 2] = batch_boxes[:, 0] + batch_boxes[:, 2]
+             batch_boxes[:, 3] = batch_boxes[:, 1] + batch_boxes[:, 3]
+             normalized_xyxy_boxes.append(copy.deepcopy(batch_boxes))
+             # scale boxes
+             original_boxes = (
+                 batch_boxes.clone()
+             )  # Copy the normalized boxes to scale to original sizes
+             original_boxes[:, 0] = original_boxes[:, 0] * W
+             original_boxes[:, 1] = original_boxes[:, 1] * H
+             original_boxes[:, 2] = original_boxes[:, 2] * W
+             original_boxes[:, 3] = original_boxes[:, 3] * H
+             original_xyxy_boxes.append(original_boxes)
+
+         original_xyxy_boxes = torch.stack(original_xyxy_boxes)
+         original_xyxy_boxes = original_xyxy_boxes.numpy()
+
+         # sort everything by score from highest to lowest
+         sorted_original_boxes = []
+         sorted_scores = []
+         for i in range(len(normalized_xyxy_boxes)):
+             scores_i = scores[i] if scores is not None else None
+             # sort in descending order
+             sorted_indices = scores_i.squeeze(-1).argsort(descending=True)
+             sorted_original_boxes.append(original_xyxy_boxes[i][sorted_indices])
+             sorted_scores.append(scores_i[sorted_indices])
+
+         original_xyxy_boxes = np.stack(sorted_original_boxes)
+         scores = torch.stack(sorted_scores)
+
+         return dict(
+             original_xyxy_boxes=original_xyxy_boxes,
+             scores=scores,
+         )
+
+     def filter(self, result: Dict, min_score: float, nms_value: float = 0.8):
+         """Filter the UPN detection result. Only keep boxes with score above min_score
+         and apply Non-Maximum Suppression (NMS) to filter overlapping boxes.
+
+         Args:
+             result (Dict): A dictionary containing detection results with 'original_xyxy_boxes' and 'scores'.
+             min_score (float): Minimum score threshold for keeping a box.
+             nms_value (float): NMS threshold for filtering boxes.
+
+         Returns:
+             Dict: Filtered result containing 'original_xyxy_boxes' and 'scores' with the filtered boxes.
+         """
+         filtered_result = {"original_xyxy_boxes": [], "scores": []}
+
+         for boxes, scores in zip(
+             np.array(result["original_xyxy_boxes"]), result["scores"].numpy()
+         ):
+             # Filter out boxes with score below min_score
+             keep = scores >= min_score
+             boxes = boxes[keep[:, 0]]
+             scores = scores[keep[:, 0]][:, 0]
+
+             if len(boxes) == 0:
+                 return filtered_result
+
+             # Convert to torch tensors
+             boxes = torch.tensor(boxes, dtype=torch.float32)
+             scores = torch.tensor(scores, dtype=torch.float32)
+
+             # Apply Non-Maximum Suppression (NMS)
+             if nms_value > 0:
+                 keep_indices = nms(boxes, scores, nms_value)
+             else:
+                 keep_indices = torch.arange(len(boxes))
+
+             # Keep only the boxes that passed NMS
+             filtered_boxes = boxes[keep_indices].numpy().astype(np.int32)
+             filtered_scores = scores[keep_indices].numpy()
+
+             # Sort the boxes by score in descending order
+             sorted_indices = np.argsort(filtered_scores)[::-1]
+             filtered_boxes = filtered_boxes[sorted_indices]
+             filtered_scores = filtered_scores[sorted_indices]
+
+             # Round the scores to two decimal places
+             filtered_scores = [round(score, 2) for score in filtered_scores]
+
+             # Store the filtered boxes and scores in the result dictionary
+             filtered_result["original_xyxy_boxes"].append(filtered_boxes.tolist())
+             filtered_result["scores"].append(filtered_scores)
+
+         return filtered_result
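Putting the pieces above together, a typical call sequence looks like the sketch below. The checkpoint path, the image path, and the threshold values are assumptions; the keys and shapes follow the docstrings in this file.

from detect_tools.upn import UPNWrapper

upn = UPNWrapper("./resources/upn_large.pth")  # assumed checkpoint path
raw = upn.inference("street.jpg", prompt_type="fine_grained_prompt")  # hypothetical image path
# raw["original_xyxy_boxes"]: (1, 900, 4) pixel-space xyxy boxes, sorted by score
# raw["scores"]:              per-box objectness scores in the same order

kept = upn.filter(raw, min_score=0.3, nms_value=0.8)
boxes = kept["original_xyxy_boxes"][0]   # list of [x1, y1, x2, y2] after thresholding and NMS
scores = kept["scores"][0]               # matching scores, rounded to two decimals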
detect_tools/upn/models/architecture/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .deformable_transformer import DeformableTransformer
+ from .upn_model import UPN
+
+ __all__ = ["UPN", "DeformableTransformer"]
detect_tools/upn/models/architecture/deformable_transformer.py ADDED
@@ -0,0 +1,336 @@
1
+ import math
2
+ from typing import Dict, List
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from detect_tools.upn import ARCHITECTURES, build_decoder, build_encoder
8
+ from detect_tools.upn.models.utils import (gen_encoder_output_proposals,
9
+ inverse_sigmoid)
10
+ from detect_tools.upn.ops.modules import MSDeformAttn
11
+
12
+
13
+ @ARCHITECTURES.register_module()
14
+ class DeformableTransformer(nn.Module):
15
+ """Implementation of Deformable DETR.
16
+
17
+ Args:
18
+ encoder_cfg (Dict): Config for the TransformerEncoder.
19
+ decoder_cfg (Dict): Config for the TransformerDecoder.
20
+ num_queries (int): Number of queries. This is for matching part. Default: 900.
21
+ d_model (int): Dimension of the model. Default: 256.
22
+ num_feature_levels (int): Number of feature levels. Default: 1.
23
+ binary_query_selection (bool): Whether to use binary query selection. Default: False.
24
+ When using binary query selection, a linear layer with out_channels=1 will be used to select
25
+ topk proposals. Otherwise, we will use ContrastiveAssign to select topk proposals.
26
+ learnable_tgt_init (bool): Whether to use learnable target init. Default: True. If False,
27
+ we will use the topk encoder features as the target init.
28
+ random_refpoints_xy (bool): Whether to use random refpoints xy. This is only used when
29
+ two_stage_type is not 'no'. Default: False. If True, we will use random refpoints xy.
30
+ two_stage_type (str): Type of two stage. Default: 'standard'. Options: 'no', 'standard'
31
+ two_stage_learn_wh (bool): Whether to learn the width and height of anchor boxes. Default: False.
32
+ two_stage_keep_all_tokens (bool): If False, the returned hs_enc, ref_enc, init_box_proposal
33
+ will only be the topk proposals. Otherwise, we will return all the proposals from the
34
+ encoder. Default: False.
35
+ two_stage_bbox_embed_share (bool): Whether to share the bbox embedding between the two stages.
36
+ Default: False.
37
+ two_stage_class_embed_share (bool): Whether to share the class embedding between the two stages.
38
+ rm_self_attn_layers (List[int]): The indices of the decoder layers to remove self-attention.
39
+ Default: None.
40
+ rm_detach (bool): Whether to detach the decoder output. Default: None.
41
+ embed_init_tgt (bool): If true, the target embedding is learnable. Otherwise, we will use
42
+ the topk encoder features as the target init. Default: True.
43
+ """
44
+
45
+ def __init__(
46
+ self,
47
+ encoder_cfg: Dict,
48
+ decoder_cfg: Dict,
49
+ mask_decoder_cfg: Dict = None,
50
+ num_queries: int = 900,
51
+ d_model: int = 256,
52
+ num_feature_levels: int = 4,
53
+ binary_query_selection: bool = False,
54
+ # init query (target)
55
+ learnable_tgt_init=True,
56
+ random_refpoints_xy=False,
57
+ # for two stage
58
+ two_stage_type: str = "standard",
59
+ two_stage_learn_wh: bool = False,
60
+ two_stage_keep_all_tokens: bool = False,
61
+ two_stage_bbox_embed_share: bool = False,
62
+ two_stage_class_embed_share: bool = False,
63
+ # evo of #anchors
64
+ rm_self_attn_layers: List[int] = None,
65
+ # for detach
66
+ rm_detach: bool = None,
67
+ with_encoder_out: bool = True,
68
+ ) -> None:
69
+ super().__init__()
70
+ self.binary_query_selection = binary_query_selection
71
+ self.num_queries = num_queries
72
+ self.num_feature_levels = num_feature_levels
73
+ self.rm_self_attn_layers = rm_self_attn_layers
74
+ self.d_model = d_model
75
+ self.two_stage_bbox_embed_share = two_stage_bbox_embed_share
76
+ self.two_stage_class_embed_share = two_stage_class_embed_share
77
+
78
+ if self.binary_query_selection:
79
+ self.binary_query_selection_layer = nn.Linear(d_model, 1)
80
+
81
+ # build encoder
82
+ self.encoder = build_encoder(encoder_cfg)
83
+
84
+ # build decoder
85
+ self.decoder = build_decoder(decoder_cfg)
86
+ self.num_decoder_layers = self.decoder.num_layers
87
+
88
+ # build sole mask decoder
89
+ if mask_decoder_cfg is not None:
90
+ self.mask_decoder = build_decoder(mask_decoder_cfg)
91
+ else:
92
+ self.mask_decoder = None
93
+ # level embedding
94
+ if num_feature_levels > 1:
95
+ self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
96
+
97
+ # learnable target embedding
98
+ self.learnable_tgt_init = learnable_tgt_init
99
+ assert learnable_tgt_init, "why not learnable_tgt_init"
100
+
101
+ self.tgt_embed = nn.Embedding(num_queries, d_model)
102
+ nn.init.normal_(self.tgt_embed.weight.data)
103
+
104
+ # for two stage
105
+ # TODO: this part is really confusing
106
+ self.two_stage_type = two_stage_type
107
+ self.two_stage_learn_wh = two_stage_learn_wh
108
+ self.two_stage_keep_all_tokens = two_stage_keep_all_tokens
109
+ assert two_stage_type in [
110
+ "no",
111
+ "standard",
112
+ ], "unknown param {} of two_stage_type".format(two_stage_type)
113
+ self.with_encoder_out = with_encoder_out
114
+ if two_stage_type == "standard":
115
+ # anchor selection at the output of encoder
116
+ if with_encoder_out:
117
+ self.enc_output = nn.Linear(d_model, d_model)
118
+ self.enc_output_norm = nn.LayerNorm(d_model)
119
+
120
+ if two_stage_learn_wh:
121
+ # import ipdb; ipdb.set_trace()
122
+ self.two_stage_wh_embedding = nn.Embedding(1, 2)
123
+ else:
124
+ self.two_stage_wh_embedding = None
125
+
126
+ elif two_stage_type == "no":
127
+ self.init_ref_points(
128
+ num_queries, random_refpoints_xy
129
+ ) # init self.refpoint_embed
130
+
131
+ self.enc_out_class_embed = None # this will be initialized outside of the model
132
+ self.enc_out_bbox_embed = None # this will be initialized outside of the model
133
+
134
+ # remove some self_attn_layers or rm_detach
135
+ self._reset_parameters()
136
+
137
+ self.rm_self_attn_layers = rm_self_attn_layers
138
+ if rm_self_attn_layers is not None:
139
+ # assert len(rm_self_attn_layers) == num_decoder_layers
140
+ print(
141
+ "Removing the self-attn in {} decoder layers".format(
142
+ rm_self_attn_layers
143
+ )
144
+ )
145
+ for lid, dec_layer in enumerate(self.decoder.layers):
146
+ if lid in rm_self_attn_layers:
147
+ dec_layer.rm_self_attn_modules()
148
+
149
+ self.rm_detach = rm_detach
150
+ if self.rm_detach:
151
+ assert isinstance(rm_detach, list)
152
+ assert any([i in ["enc_ref", "enc_tgt", "dec"] for i in rm_detach])
153
+ self.decoder.rm_detach = rm_detach
154
+
155
+ def _reset_parameters(self):
156
+ for p in self.parameters():
157
+ if p.dim() > 1:
158
+ nn.init.xavier_uniform_(p)
159
+ for m in self.modules():
160
+ if isinstance(m, MSDeformAttn):
161
+ m._reset_parameters()
162
+ if self.num_feature_levels > 1 and self.level_embed is not None:
163
+ nn.init.normal_(self.level_embed)
164
+
165
+ if self.two_stage_learn_wh:
166
+ nn.init.constant_(
167
+ self.two_stage_wh_embedding.weight, math.log(0.05 / (1 - 0.05))
168
+ )
169
+
170
+ def init_ref_points(self, num_queries: int, random_refpoints_xy: bool = False):
171
+ """Initialize learnable reference points for each query.
172
+
173
+ Args:
174
+ num_queries (int): number of queries
175
+ random_refpoints_xy (bool, optional): whether to init the refpoints randomly.
176
+ Defaults to False.
177
+ """
178
+ self.refpoint_embed = nn.Embedding(num_queries, 4)
179
+ if random_refpoints_xy:
180
+ self.refpoint_embed.weight.data[:, :2].uniform_(0, 1)
181
+ self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(
182
+ self.refpoint_embed.weight.data[:, :2]
183
+ )
184
+ self.refpoint_embed.weight.data[:, :2].requires_grad = False
185
+
186
+ def get_valid_ratio(self, mask):
187
+ _, H, W = mask.shape
188
+ valid_H = torch.sum(~mask[:, :, 0], 1)
189
+ valid_W = torch.sum(~mask[:, 0, :], 1)
190
+ valid_ratio_h = valid_H.float() / H
191
+ valid_ratio_w = valid_W.float() / W
192
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
193
+ return valid_ratio
194
+
195
+ def forward(
196
+ self,
197
+ src_flatten: torch.Tensor,
198
+ lvl_pos_embed_flatten: torch.Tensor,
199
+ level_start_index: List[int],
200
+ spatial_shapes: List[torch.Tensor],
201
+ valid_ratios: List[torch.Tensor],
202
+ mask_flatten: torch.Tensor,
203
+ prompt_type: str,
204
+ ) -> List[torch.Tensor]:
205
+ """Forward function."""
206
+ memory = self.encoder(
207
+ src_flatten,
208
+ pos=lvl_pos_embed_flatten,
209
+ level_start_index=level_start_index,
210
+ spatial_shapes=spatial_shapes,
211
+ valid_ratios=valid_ratios,
212
+ key_padding_mask=mask_flatten,
213
+ )
214
+ batch_size = src_flatten.shape[0]
215
+ crop_region_features = torch.zeros(batch_size, 1, self.d_model).to(
216
+ memory.device
217
+ )
218
+ if prompt_type == "fine_grained_prompt":
219
+ crop_region_features = (
220
+ self.fine_grained_prompt.weight[0]
221
+ .unsqueeze(0)
222
+ .unsqueeze(0)
223
+ .repeat(batch_size, 1, 1)
224
+ )
225
+ elif prompt_type == "coarse_grained_prompt":
226
+ crop_region_features = (
227
+ self.coarse_grained_prompt.weight[0]
228
+ .unsqueeze(0)
229
+ .unsqueeze(0)
230
+ .repeat(batch_size, 1, 1)
231
+ )
232
+ pad_mask = torch.ones(batch_size, 1).to(crop_region_features.device).bool()
233
+ self_attn_mask = torch.ones(batch_size, 1, 1).to(crop_region_features.device)
234
+ ref_dict = dict(
235
+ encoded_ref_feature=crop_region_features,
236
+ pad_mask=pad_mask,
237
+ self_attn_mask=self_attn_mask,
238
+ prompt_type="universal_prompt",
239
+ )
240
+
241
+ (
242
+ refpoint_embed,
243
+ tgt,
244
+ init_box_proposal,
245
+ ) = self.get_two_stage_proposal(memory, mask_flatten, spatial_shapes, ref_dict)
246
+
247
+ hs, references = self.decoder(
248
+ tgt=tgt.transpose(0, 1),
249
+ tgt_key_padding_mask=None,
250
+ memory=memory.transpose(0, 1),
251
+ memory_key_padding_mask=mask_flatten,
252
+ pos=lvl_pos_embed_flatten.transpose(0, 1),
253
+ refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
254
+ level_start_index=level_start_index,
255
+ spatial_shapes=spatial_shapes,
256
+ valid_ratios=valid_ratios,
257
+ tgt_mask=None,
258
+ # we ~ the mask . False means use the token; True means pad the token
259
+ )
260
+ hs_enc = ref_enc = None
261
+ return (
262
+ hs,
263
+ references,
264
+ ref_dict,
265
+ )
266
+
267
+ def get_two_stage_proposal(
268
+ self,
269
+ memory: torch.Tensor,
270
+ mask_flatten: torch.Tensor,
271
+ spatial_shapes: List[torch.Tensor],
272
+ ref_dict: Dict,
273
+ ) -> List[torch.Tensor]:
274
+ """Two stage proposal generation for decoder
275
+
276
+ Args:
277
+ memory (torch.Tensor): Image encoded feature. [bs, n, 256]
278
+ mask_flatten (torch.Tensor): Flattened mask. [bs, n]
279
+ spatial_shapes (List[torch.Tensor]): Spatial shapes of each feature map. [bs, num_levels, 2]
280
+ refpoint_embed_dn (torch.Tensor): Denoising refpoint embedding. [bs, num_dn_queries, 256]
281
+ tgt_dn (torch.Tensor): Denoising target embedding. [bs, num_dn_queries, 256]
282
+ ref_dict (Dict): A dict containing all kinds of reference image related features.
283
+ """
284
+ bs = memory.shape[0]
285
+ input_hw = None
286
+ output_memory, output_proposals = gen_encoder_output_proposals(
287
+ memory, mask_flatten, spatial_shapes, input_hw
288
+ )
289
+ output_memory = self.enc_output_norm(self.enc_output(output_memory))
290
+
291
+ if self.binary_query_selection: # Unused
292
+ topk_logits = self.binary_query_selection_layer(output_memory).squeeze(-1)
293
+ else:
294
+ if ref_dict is not None:
295
+ enc_outputs_class_unselected = self.enc_out_class_embed(
296
+ output_memory, ref_dict
297
+ ) # this is not a linear prediction layer but a contrastive similarity, shape [B, len_image, len_text]
298
+ else:
299
+ enc_outputs_class_unselected = self.enc_out_class_embed(output_memory)
300
+ topk_logits = enc_outputs_class_unselected.max(-1)[
301
+ 0
302
+ ] # shape [B, len_image]
303
+ enc_outputs_coord_unselected = (
304
+ self.enc_out_bbox_embed(output_memory) + output_proposals
305
+ ) # (bs, \sum{hw}, 4) unsigmoid
306
+ topk = self.num_queries
307
+
308
+ try:
309
+ topk_proposals = torch.topk(topk_logits, topk, dim=1)[1] # bs, nq
310
+ except:
311
+ raise ValueError(f"Failed to select top-{topk} proposals from topk_logits of shape {topk_logits.shape}")
312
+
313
+ # gather boxes
314
+ refpoint_embed_undetach = torch.gather(
315
+ enc_outputs_coord_unselected,
316
+ 1,
317
+ topk_proposals.unsqueeze(-1).repeat(1, 1, 4),
318
+ ) # unsigmoid
319
+ refpoint_embed_ = refpoint_embed_undetach.detach()
320
+ init_box_proposal = torch.gather(
321
+ output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
322
+ ).sigmoid() # sigmoid
323
+ # gather tgt
324
+ tgt_undetach = torch.gather(
325
+ output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model)
326
+ )
327
+ tgt_ = (
328
+ self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
329
+ ) # nq, bs, d_model
330
+ refpoint_embed, tgt = refpoint_embed_, tgt_
331
+
332
+ return (
333
+ refpoint_embed,
334
+ tgt,
335
+ init_box_proposal,
336
+ )
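The two-stage selection in get_two_stage_proposal reduces to a top-k over per-token scores followed by gathers on the proposal boxes and encoder features. A standalone shape sketch of that pattern, with dimensions chosen only for illustration:

import torch

bs, num_tokens, d_model, num_queries = 2, 10000, 256, 900
topk_logits = torch.randn(bs, num_tokens)          # per-token objectness scores
coords_unsig = torch.randn(bs, num_tokens, 4)      # unsigmoided cxcywh proposals
memory = torch.randn(bs, num_tokens, d_model)      # encoder output features

topk_idx = torch.topk(topk_logits, num_queries, dim=1)[1]                            # (bs, 900)
refpoints = torch.gather(coords_unsig, 1, topk_idx.unsqueeze(-1).repeat(1, 1, 4))    # (bs, 900, 4)
topk_feats = torch.gather(memory, 1, topk_idx.unsqueeze(-1).repeat(1, 1, d_model))   # (bs, 900, 256)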
detect_tools/upn/models/architecture/upn_model.py ADDED
@@ -0,0 +1,343 @@
1
+ import copy
2
+ from typing import Dict, List, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from detect_tools.upn import ARCHITECTURES, build_architecture, build_backbone
9
+ from detect_tools.upn.models.module import (MLP, ContrastiveAssign, NestedTensor,
10
+ nested_tensor_from_tensor_list)
11
+ from detect_tools.upn.models.utils import inverse_sigmoid
12
+
13
+
14
+ class LayerNorm2d(nn.Module):
15
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
16
+ super().__init__()
17
+ self.weight = nn.Parameter(torch.ones(num_channels))
18
+ self.bias = nn.Parameter(torch.zeros(num_channels))
19
+ self.eps = eps
20
+
21
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
22
+ u = x.mean(1, keepdim=True)
23
+ s = (x - u).pow(2).mean(1, keepdim=True)
24
+ x = (x - u) / torch.sqrt(s + self.eps)
25
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
26
+ return x
27
+
28
+
29
+ @ARCHITECTURES.register_module()
30
+ class UPN(nn.Module):
31
+ """Implementation of UPN"""
32
+
33
+ def __init__(
34
+ self,
35
+ vision_backbone_cfg: Dict,
36
+ transformer_cfg: Dict,
37
+ num_queries: int,
38
+ dec_pred_class_embed_share=True,
39
+ dec_pred_bbox_embed_share=True,
40
+ decoder_sa_type="sa",
41
+ ):
42
+ super().__init__()
43
+ # build vision backbone
44
+ self.backbone = build_backbone(vision_backbone_cfg)
45
+ # build transformer
46
+ self.transformer = build_architecture(transformer_cfg)
47
+
48
+ self.hidden_dim = self.transformer.d_model
49
+
50
+ # for dn training
51
+ self.num_queries = num_queries
52
+ self.num_feature_levels = self.transformer.num_feature_levels
53
+
54
+ # prepare projection layer for vision feature
55
+ self.input_proj = self.prepare_vision_feature_projection_layer(
56
+ self.backbone,
57
+ self.transformer.num_feature_levels,
58
+ self.hidden_dim,
59
+ self.transformer.two_stage_type,
60
+ )
61
+ # prepare prediction head
62
+ self.prepare_prediction_head(
63
+ dec_pred_class_embed_share,
64
+ dec_pred_bbox_embed_share,
65
+ self.hidden_dim,
66
+ self.transformer.num_decoder_layers,
67
+ )
68
+
69
+ self.decoder_sa_type = decoder_sa_type
70
+ assert decoder_sa_type in ["sa", "ca_label", "ca_content"]
71
+ # self.replace_sa_with_double_ca = replace_sa_with_double_ca
72
+
73
+ for layer in self.transformer.decoder.layers:
74
+ layer.label_embedding = None
75
+ self.label_embedding = None
76
+
77
+ # build a universal token
78
+ self.transformer.fine_grained_prompt = nn.Embedding(1, self.hidden_dim)
79
+ self.transformer.coarse_grained_prompt = nn.Embedding(1, self.hidden_dim)
80
+
81
+ self._reset_parameters()
82
+
83
+ def forward(self, samples: NestedTensor, prompt_type: str = None) -> Dict:
84
+ """Forward function."""
85
+ self.device = samples.device
86
+
87
+ (
88
+ src_flatten,
89
+ lvl_pos_embed_flatten,
90
+ level_start_index,
91
+ spatial_shapes,
92
+ valid_ratios,
93
+ mask_flatten,
94
+ ) = self.forward_backbone_encoder(samples)
95
+
96
+ (
97
+ hs,
98
+ reference,
99
+ ref_dict,
100
+ ) = self.transformer(
101
+ src_flatten,
102
+ lvl_pos_embed_flatten,
103
+ level_start_index,
104
+ spatial_shapes,
105
+ valid_ratios,
106
+ mask_flatten,
107
+ prompt_type,
108
+ )
109
+
110
+ # deformable-detr-like anchor update
111
+ outputs_coord_list = []
112
+ outputs_class = []
113
+
114
+ for layer_idx, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(
115
+ zip(reference[:-1], self.bbox_embed, hs)
116
+ ):
117
+ layer_delta_unsig = layer_bbox_embed(layer_hs)
118
+ layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
119
+ layer_outputs_unsig = layer_outputs_unsig.sigmoid()
120
+ outputs_coord_list.append(layer_outputs_unsig)
121
+
122
+ outputs_coord_list = torch.stack(outputs_coord_list)
123
+
124
+ if ref_dict is None:
125
+ # build a mock outputs_class for mask_dn training
126
+ outputs_class = torch.zeros(
127
+ outputs_coord_list.shape[0],
128
+ outputs_coord_list.shape[1],
129
+ outputs_coord_list.shape[2],
130
+ self.hidden_dim,
131
+ )
132
+ else:
133
+ outputs_class = torch.stack(
134
+ [
135
+ layer_cls_embed(layer_hs, ref_dict)
136
+ for layer_cls_embed, layer_hs in zip(self.class_embed, hs)
137
+ ]
138
+ )
139
+
140
+ out = {
141
+ "pred_logits": outputs_class[-1],
142
+ "pred_boxes": outputs_coord_list[-1],
143
+ }
144
+ out["ref_dict"] = ref_dict
145
+ return out
146
+
147
+ def forward_backbone_encoder(self, samples: NestedTensor) -> Tuple:
148
+ # pass through backbone
149
+ if isinstance(samples, (list, torch.Tensor)):
150
+ samples = nested_tensor_from_tensor_list(samples)
151
+ features, poss = self.backbone(samples)
152
+ # project features
153
+ srcs = []
154
+ masks = []
155
+ for l, feat in enumerate(features):
156
+ src, mask = feat.decompose()
157
+ srcs.append(self.input_proj[l](src)) # downsample the feature map to 256
158
+ masks.append(mask)
159
+ assert mask is not None
160
+
161
+ if self.num_feature_levels > len(
162
+ srcs
163
+ ): # add more feature levels by downsampling the last feature map
164
+ _len_srcs = len(srcs)
165
+ for l in range(_len_srcs, self.num_feature_levels):
166
+ if l == _len_srcs:
167
+ src = self.input_proj[l](features[-1].tensors)
168
+ else:
169
+ src = self.input_proj[l](srcs[-1])
170
+ m = samples.mask
171
+ mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(
172
+ torch.bool
173
+ )[0]
174
+ pos_l = self.backbone.forward_pos_embed_only(
175
+ NestedTensor(src, mask)
176
+ ).to(src.dtype)
177
+ srcs.append(src)
178
+ masks.append(mask)
179
+ poss.append(pos_l)
180
+
181
+ # prepare input for encoder with the following steps:
182
+ # 1. flatten the feature maps and masks
183
+ # 2. Add positional embedding and level embedding
184
+ # 3. Calculate the valid ratio of each feature map based on the mask
185
+ src_flatten = []
186
+ mask_flatten = []
187
+ lvl_pos_embed_flatten = []
188
+ spatial_shapes = []
189
+ for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, poss)):
190
+ bs, c, h, w = src.shape
191
+ spatial_shape = (h, w)
192
+ spatial_shapes.append(spatial_shape)
193
+
194
+ src = src.flatten(2).transpose(1, 2) # bs, hw, c
195
+ mask = mask.flatten(1) # bs, hw
196
+ pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c
197
+ if self.num_feature_levels > 1 and self.transformer.level_embed is not None:
198
+ lvl_pos_embed = pos_embed + self.transformer.level_embed[lvl].view(
199
+ 1, 1, -1
200
+ )
201
+ else:
202
+ lvl_pos_embed = pos_embed
203
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
204
+ src_flatten.append(src)
205
+ mask_flatten.append(mask)
206
+ src_flatten = torch.cat(src_flatten, 1)
207
+ mask_flatten = torch.cat(mask_flatten, 1)
208
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
209
+ spatial_shapes = torch.as_tensor(
210
+ spatial_shapes, dtype=torch.long, device=src_flatten.device
211
+ )
212
+ level_start_index = torch.cat(
213
+ (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
214
+ )
215
+ valid_ratios = torch.stack(
216
+ [self.transformer.get_valid_ratio(m) for m in masks], 1
217
+ )
218
+
219
+ return (
220
+ src_flatten,
221
+ lvl_pos_embed_flatten,
222
+ level_start_index,
223
+ spatial_shapes,
224
+ valid_ratios,
225
+ mask_flatten,
226
+ )
227
+
228
+ def prepare_vision_feature_projection_layer(
229
+ self,
230
+ backbone: nn.Module,
231
+ num_feature_levels: int,
232
+ hidden_dim: int,
233
+ two_stage_type: str,
234
+ ) -> nn.ModuleList:
235
+ """Prepare projection layer to map backbone feature to hidden dim.
236
+
237
+ Args:
238
+ backbone (nn.Module): Backbone.
239
+ num_feature_levels (int): Number of feature levels.
240
+ hidden_dim (int): Hidden dim.
241
+ two_stage_type (str): Type of two stage.
242
+
243
+ Returns:
244
+ nn.ModuleList: Projection layer.
245
+ """
246
+ if num_feature_levels > 1:
247
+ num_backbone_outs = len(backbone.num_channels)
248
+ input_proj_list = []
249
+ for _ in range(num_backbone_outs):
250
+ in_channels = backbone.num_channels[_]
251
+ input_proj_list.append(
252
+ nn.Sequential(
253
+ nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
254
+ nn.GroupNorm(32, hidden_dim),
255
+ )
256
+ )
257
+ for _ in range(num_feature_levels - num_backbone_outs):
258
+ input_proj_list.append(
259
+ nn.Sequential(
260
+ nn.Conv2d(
261
+ in_channels, hidden_dim, kernel_size=3, stride=2, padding=1
262
+ ),
263
+ nn.GroupNorm(32, hidden_dim),
264
+ )
265
+ )
266
+ in_channels = hidden_dim
267
+ input_proj = nn.ModuleList(input_proj_list)
268
+ else:
269
+ assert (
270
+ two_stage_type == "no"
271
+ ), "two_stage_type should be no if num_feature_levels=1 !!!"
272
+ input_proj = nn.ModuleList(
273
+ [
274
+ nn.Sequential(
275
+ nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1),
276
+ nn.GroupNorm(32, hidden_dim),
277
+ )
278
+ ]
279
+ )
280
+ return input_proj
281
+
282
+ def prepare_prediction_head(
283
+ self,
284
+ dec_pred_class_embed_share: bool,
285
+ dec_pred_bbox_embed_share: bool,
286
+ hidden_dim: int,
287
+ num_decoder_layers: int,
288
+ ) -> Union[nn.ModuleList, nn.ModuleList]:
289
+ """Prepare prediction head. Including class embed and bbox embed.
290
+
291
+ Args:
292
+ dec_pred_class_embed_share (bool): Whether to share class embed for all decoder layers.
293
+ dec_pred_bbox_embed_share (bool): Whether to share bbox embed for all decoder layers.
294
+ hidden_dim (int): Hidden dim.
295
+ num_decoder_layers (int): Number of decoder layers.
296
+
297
+ """
298
+ _class_embed = ContrastiveAssign()
299
+ _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
300
+ nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
301
+ nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
302
+ if dec_pred_bbox_embed_share:
303
+ box_embed_layerlist = [_bbox_embed for _ in range(num_decoder_layers)]
304
+ else:
305
+ box_embed_layerlist = [
306
+ copy.deepcopy(_bbox_embed) for i in range(num_decoder_layers)
307
+ ]
308
+ if dec_pred_class_embed_share:
309
+ class_embed_layerlist = [_class_embed for i in range(num_decoder_layers)]
310
+ else:
311
+ class_embed_layerlist = [
312
+ copy.deepcopy(_class_embed) for i in range(num_decoder_layers)
313
+ ]
314
+ bbox_embed = nn.ModuleList(box_embed_layerlist)
315
+ class_embed = nn.ModuleList(class_embed_layerlist)
316
+ self.bbox_embed = bbox_embed
317
+ self.class_embed = class_embed
318
+
319
+ # initialize bbox embed and class embed in transformer
320
+ self.transformer.decoder.bbox_embed = bbox_embed
321
+ self.transformer.decoder.class_embed = class_embed
322
+
323
+ if self.transformer.two_stage_type != "no":
324
+ if self.transformer.two_stage_bbox_embed_share:
325
+ assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
326
+ self.transformer.enc_out_bbox_embed = _bbox_embed
327
+ else:
328
+ self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)
329
+
330
+ if self.transformer.two_stage_class_embed_share:
331
+ assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
332
+ self.transformer.enc_out_class_embed = _class_embed
333
+ else:
334
+ self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed)
335
+
336
+ self.refpoint_embed = None
337
+
338
+ def _reset_parameters(self):
339
+ # init input_proj
340
+ for proj in self.input_proj:
341
+
342
+ nn.init.xavier_uniform_(proj[0].weight, gain=1)
343
+ nn.init.constant_(proj[0].bias, 0)
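forward_backbone_encoder flattens every feature level into a single token sequence and records where each level starts. The index arithmetic is easy to check in isolation; the feature-map sizes below are illustrative only.

import torch

spatial_shapes = torch.tensor([[100, 167], [50, 84], [25, 42], [13, 21], [7, 11]])  # (H, W) per level
tokens_per_level = spatial_shapes.prod(1)        # tensor([16700, 4200, 1050, 273, 77])
level_start_index = torch.cat(
    (spatial_shapes.new_zeros((1,)), tokens_per_level.cumsum(0)[:-1])
)
print(level_start_index)                         # tensor([    0, 16700, 20900, 21950, 22223])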
detect_tools/upn/models/backbone/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .swin import SwinTransformer
+ from .wrapper import SwinWrapper
+
+ __all__ = ["SwinWrapper", "SwinTransformer"]
detect_tools/upn/models/backbone/swin.py ADDED
@@ -0,0 +1,814 @@
1
+ from typing import Dict, List
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.utils.checkpoint as checkpoint
8
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
9
+
10
+ from detect_tools.upn import BACKBONES
11
+ from detect_tools.upn.models.module import NestedTensor
12
+
13
+
14
+ class Mlp(nn.Module):
15
+ """Multilayer perceptron."""
16
+
17
+ def __init__(
18
+ self,
19
+ in_features,
20
+ hidden_features=None,
21
+ out_features=None,
22
+ act_layer=nn.GELU,
23
+ drop=0.0,
24
+ ):
25
+ super().__init__()
26
+ out_features = out_features or in_features
27
+ hidden_features = hidden_features or in_features
28
+ self.fc1 = nn.Linear(in_features, hidden_features)
29
+ self.act = act_layer()
30
+ self.fc2 = nn.Linear(hidden_features, out_features)
31
+ self.drop = nn.Dropout(drop)
32
+
33
+ def forward(self, x):
34
+ x = self.fc1(x)
35
+ x = self.act(x)
36
+ x = self.drop(x)
37
+ x = self.fc2(x)
38
+ x = self.drop(x)
39
+ return x
40
+
41
+
42
+ def window_partition(x, window_size):
43
+ """
44
+ Args:
45
+ x: (B, H, W, C)
46
+ window_size (int): window size
47
+ Returns:
48
+ windows: (num_windows*B, window_size, window_size, C)
49
+ """
50
+ B, H, W, C = x.shape
51
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
52
+ windows = (
53
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
54
+ )
55
+ return windows
56
+
57
+
58
+ def window_reverse(windows, window_size, H, W):
59
+ """
60
+ Args:
61
+ windows: (num_windows*B, window_size, window_size, C)
62
+ window_size (int): Window size
63
+ H (int): Height of image
64
+ W (int): Width of image
65
+ Returns:
66
+ x: (B, H, W, C)
67
+ """
68
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
69
+ x = windows.view(
70
+ B, H // window_size, W // window_size, window_size, window_size, -1
71
+ )
72
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
73
+ return x
74
+
75
+
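A quick shape check of the two helpers above, with values picked only for illustration: window_partition slices a feature map into non-overlapping windows and window_reverse undoes it exactly (run with the two functions from this module in scope).

import torch

B, H, W, C, ws = 2, 56, 56, 96, 7
x = torch.randn(B, H, W, C)
windows = window_partition(x, ws)      # (B * H/ws * W/ws, ws, ws, C) == (128, 7, 7, 96)
y = window_reverse(windows, ws, H, W)  # back to (2, 56, 56, 96)
assert torch.equal(x, y)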
76
+ class WindowAttention(nn.Module):
77
+ """Window based multi-head self attention (W-MSA) module with relative position bias.
78
+ It supports both of shifted and non-shifted window.
79
+ Args:
80
+ dim (int): Number of input channels.
81
+ window_size (tuple[int]): The height and width of the window.
82
+ num_heads (int): Number of attention heads.
83
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
84
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
85
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
86
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
87
+ """
88
+
89
+ def __init__(
90
+ self,
91
+ dim,
92
+ window_size,
93
+ num_heads,
94
+ qkv_bias=True,
95
+ qk_scale=None,
96
+ attn_drop=0.0,
97
+ proj_drop=0.0,
98
+ ):
99
+
100
+ super().__init__()
101
+ self.dim = dim
102
+ self.window_size = window_size # Wh, Ww
103
+ self.num_heads = num_heads
104
+ head_dim = dim // num_heads
105
+ self.scale = qk_scale or head_dim**-0.5
106
+
107
+ # define a parameter table of relative position bias
108
+ self.relative_position_bias_table = nn.Parameter(
109
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
110
+ ) # 2*Wh-1 * 2*Ww-1, nH
111
+
112
+ # get pair-wise relative position index for each token inside the window
113
+ coords_h = torch.arange(self.window_size[0])
114
+ coords_w = torch.arange(self.window_size[1])
115
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
116
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
117
+ relative_coords = (
118
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
119
+ ) # 2, Wh*Ww, Wh*Ww
120
+ relative_coords = relative_coords.permute(
121
+ 1, 2, 0
122
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
123
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
124
+ relative_coords[:, :, 1] += self.window_size[1] - 1
125
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
126
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
127
+ self.register_buffer("relative_position_index", relative_position_index)
128
+
129
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
130
+ self.attn_drop = nn.Dropout(attn_drop)
131
+ self.proj = nn.Linear(dim, dim)
132
+ self.proj_drop = nn.Dropout(proj_drop)
133
+
134
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
135
+ self.softmax = nn.Softmax(dim=-1)
136
+
137
+ def forward(self, x, mask=None):
138
+ """Forward function.
139
+ Args:
140
+ x: input features with shape of (num_windows*B, N, C)
141
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
142
+ """
143
+ B_, N, C = x.shape
144
+ qkv = (
145
+ self.qkv(x)
146
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
147
+ .permute(2, 0, 3, 1, 4)
148
+ )
149
+ q, k, v = (
150
+ qkv[0],
151
+ qkv[1],
152
+ qkv[2],
153
+ ) # make torchscript happy (cannot use tensor as tuple)
154
+
155
+ q = q * self.scale
156
+ attn = q @ k.transpose(-2, -1)
157
+
158
+ relative_position_bias = self.relative_position_bias_table[
159
+ self.relative_position_index.view(-1)
160
+ ].view(
161
+ self.window_size[0] * self.window_size[1],
162
+ self.window_size[0] * self.window_size[1],
163
+ -1,
164
+ ) # Wh*Ww,Wh*Ww,nH
165
+ relative_position_bias = relative_position_bias.permute(
166
+ 2, 0, 1
167
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
168
+ attn = attn + relative_position_bias.unsqueeze(0)
169
+
170
+ if mask is not None:
171
+ nW = mask.shape[0]
172
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
173
+ 1
174
+ ).unsqueeze(0)
175
+ attn = attn.view(-1, self.num_heads, N, N)
176
+ attn = self.softmax(attn)
177
+ else:
178
+ attn = self.softmax(attn)
179
+
180
+ attn = self.attn_drop(attn)
181
+
182
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
183
+ x = self.proj(x)
184
+ x = self.proj_drop(x)
185
+ return x
186
+
187
+
188
+ class SwinTransformerBlock(nn.Module):
189
+ """Swin Transformer Block.
190
+ Args:
191
+ dim (int): Number of input channels.
192
+ num_heads (int): Number of attention heads.
193
+ window_size (int): Window size.
194
+ shift_size (int): Shift size for SW-MSA.
195
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
196
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
197
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
198
+ drop (float, optional): Dropout rate. Default: 0.0
199
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
200
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
201
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
202
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
203
+ """
204
+
205
+ def __init__(
206
+ self,
207
+ dim,
208
+ num_heads,
209
+ window_size=7,
210
+ shift_size=0,
211
+ mlp_ratio=4.0,
212
+ qkv_bias=True,
213
+ qk_scale=None,
214
+ drop=0.0,
215
+ attn_drop=0.0,
216
+ drop_path=0.0,
217
+ act_layer=nn.GELU,
218
+ norm_layer=nn.LayerNorm,
219
+ ):
220
+ super().__init__()
221
+ self.dim = dim
222
+ self.num_heads = num_heads
223
+ self.window_size = window_size
224
+ self.shift_size = shift_size
225
+ self.mlp_ratio = mlp_ratio
226
+ assert (
227
+ 0 <= self.shift_size < self.window_size
228
+ ), "shift_size must in 0-window_size"
229
+
230
+ self.norm1 = norm_layer(dim)
231
+ self.attn = WindowAttention(
232
+ dim,
233
+ window_size=to_2tuple(self.window_size),
234
+ num_heads=num_heads,
235
+ qkv_bias=qkv_bias,
236
+ qk_scale=qk_scale,
237
+ attn_drop=attn_drop,
238
+ proj_drop=drop,
239
+ )
240
+
241
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
242
+ self.norm2 = norm_layer(dim)
243
+ mlp_hidden_dim = int(dim * mlp_ratio)
244
+ self.mlp = Mlp(
245
+ in_features=dim,
246
+ hidden_features=mlp_hidden_dim,
247
+ act_layer=act_layer,
248
+ drop=drop,
249
+ )
250
+
251
+ self.H = None
252
+ self.W = None
253
+
254
+ def forward(self, x, mask_matrix):
255
+ """Forward function.
256
+ Args:
257
+ x: Input feature, tensor size (B, H*W, C).
258
+ H, W: Spatial resolution of the input feature.
259
+ mask_matrix: Attention mask for cyclic shift.
260
+ """
261
+ B, L, C = x.shape
262
+ H, W = self.H, self.W
263
+ assert L == H * W, "input feature has wrong size"
264
+
265
+ shortcut = x
266
+ x = self.norm1(x)
267
+ x = x.view(B, H, W, C)
268
+
269
+ # pad feature maps to multiples of window size
270
+ pad_l = pad_t = 0
271
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
272
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
273
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
274
+ _, Hp, Wp, _ = x.shape
275
+
276
+ # cyclic shift
277
+ if self.shift_size > 0:
278
+ shifted_x = torch.roll(
279
+ x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
280
+ )
281
+ attn_mask = mask_matrix
282
+ else:
283
+ shifted_x = x
284
+ attn_mask = None
285
+
286
+ # partition windows
287
+ x_windows = window_partition(
288
+ shifted_x, self.window_size
289
+ ) # nW*B, window_size, window_size, C
290
+ x_windows = x_windows.view(
291
+ -1, self.window_size * self.window_size, C
292
+ ) # nW*B, window_size*window_size, C
293
+
294
+ # W-MSA/SW-MSA
295
+ attn_windows = self.attn(
296
+ x_windows, mask=attn_mask
297
+ ) # nW*B, window_size*window_size, C
298
+
299
+ # merge windows
300
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
301
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
302
+
303
+ # reverse cyclic shift
304
+ if self.shift_size > 0:
305
+ x = torch.roll(
306
+ shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
307
+ )
308
+ else:
309
+ x = shifted_x
310
+
311
+ if pad_r > 0 or pad_b > 0:
312
+ x = x[:, :H, :W, :].contiguous()
313
+
314
+ x = x.view(B, H * W, C)
315
+
316
+ # FFN
317
+ x = shortcut + self.drop_path(x)
318
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
319
+
320
+ return x
321
+
322
+
323
+ class PatchMerging(nn.Module):
324
+ """Patch Merging Layer
325
+ Args:
326
+ dim (int): Number of input channels.
327
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
328
+ """
329
+
330
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
331
+ super().__init__()
332
+ self.dim = dim
333
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
334
+ self.norm = norm_layer(4 * dim)
335
+
336
+ def forward(self, x, H, W):
337
+ """Forward function.
338
+ Args:
339
+ x: Input feature, tensor size (B, H*W, C).
340
+ H, W: Spatial resolution of the input feature.
341
+ """
342
+ B, L, C = x.shape
343
+ assert L == H * W, "input feature has wrong size"
344
+
345
+ x = x.view(B, H, W, C)
346
+
347
+ # padding
348
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
349
+ if pad_input:
350
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
351
+
352
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
353
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
354
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
355
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
356
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
357
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
358
+
359
+ x = self.norm(x)
360
+ x = self.reduction(x)
361
+
362
+ return x
363
+
364
+
365
+ class BasicLayer(nn.Module):
366
+ """A basic Swin Transformer layer for one stage.
367
+ Args:
368
+ dim (int): Number of feature channels
369
+ depth (int): Depths of this stage.
370
+ num_heads (int): Number of attention head.
371
+ window_size (int): Local window size. Default: 7.
372
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
373
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
374
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
375
+ drop (float, optional): Dropout rate. Default: 0.0
376
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
377
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
378
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
379
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
380
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
381
+ """
382
+
383
+ def __init__(
384
+ self,
385
+ dim,
386
+ depth,
387
+ num_heads,
388
+ window_size=7,
389
+ mlp_ratio=4.0,
390
+ qkv_bias=True,
391
+ qk_scale=None,
392
+ drop=0.0,
393
+ attn_drop=0.0,
394
+ drop_path=0.0,
395
+ norm_layer=nn.LayerNorm,
396
+ downsample=None,
397
+ use_checkpoint=False,
398
+ ):
399
+ super().__init__()
400
+ self.window_size = window_size
401
+ self.shift_size = window_size // 2
402
+ self.depth = depth
403
+ self.use_checkpoint = use_checkpoint
404
+
405
+ # build blocks
406
+ self.blocks = nn.ModuleList(
407
+ [
408
+ SwinTransformerBlock(
409
+ dim=dim,
410
+ num_heads=num_heads,
411
+ window_size=window_size,
412
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
413
+ mlp_ratio=mlp_ratio,
414
+ qkv_bias=qkv_bias,
415
+ qk_scale=qk_scale,
416
+ drop=drop,
417
+ attn_drop=attn_drop,
418
+ drop_path=(
419
+ drop_path[i] if isinstance(drop_path, list) else drop_path
420
+ ),
421
+ norm_layer=norm_layer,
422
+ )
423
+ for i in range(depth)
424
+ ]
425
+ )
426
+
427
+ # patch merging layer
428
+ if downsample is not None:
429
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
430
+ else:
431
+ self.downsample = None
432
+
433
+ def forward(self, x, H, W):
434
+ """Forward function.
435
+ Args:
436
+ x: Input feature, tensor size (B, H*W, C).
437
+ H, W: Spatial resolution of the input feature.
438
+ """
439
+
440
+ # calculate attention mask for SW-MSA
441
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
442
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
443
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
444
+ h_slices = (
445
+ slice(0, -self.window_size),
446
+ slice(-self.window_size, -self.shift_size),
447
+ slice(-self.shift_size, None),
448
+ )
449
+ w_slices = (
450
+ slice(0, -self.window_size),
451
+ slice(-self.window_size, -self.shift_size),
452
+ slice(-self.shift_size, None),
453
+ )
454
+ cnt = 0
455
+ for h in h_slices:
456
+ for w in w_slices:
457
+ img_mask[:, h, w, :] = cnt
458
+ cnt += 1
459
+
460
+ mask_windows = window_partition(
461
+ img_mask, self.window_size
462
+ ) # nW, window_size, window_size, 1
463
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
464
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
465
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
466
+ attn_mask == 0, float(0.0)
467
+ )
468
+
469
+ for blk in self.blocks:
470
+ blk.H, blk.W = H, W
471
+ if self.use_checkpoint:
472
+ x = checkpoint.checkpoint(blk, x, attn_mask)
473
+ else:
474
+ x = blk(x, attn_mask)
475
+ if self.downsample is not None:
476
+ x_down = self.downsample(x, H, W)
477
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
478
+ return x, H, W, x_down, Wh, Ww
479
+ else:
480
+ return x, H, W, x, H, W
481
+
482
+
483
+ class PatchEmbed(nn.Module):
484
+ """Image to Patch Embedding
485
+ Args:
486
+ patch_size (int): Patch token size. Default: 4.
487
+ in_chans (int): Number of input image channels. Default: 3.
488
+ embed_dim (int): Number of linear projection output channels. Default: 96.
489
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
490
+ """
491
+
492
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
493
+ super().__init__()
494
+ patch_size = to_2tuple(patch_size)
495
+ self.patch_size = patch_size
496
+
497
+ self.in_chans = in_chans
498
+ self.embed_dim = embed_dim
499
+
500
+ self.proj = nn.Conv2d(
501
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
502
+ )
503
+ if norm_layer is not None:
504
+ self.norm = norm_layer(embed_dim)
505
+ else:
506
+ self.norm = None
507
+
508
+ def forward(self, x):
509
+ """Forward function."""
510
+ # padding
511
+ _, _, H, W = x.size()
512
+ if W % self.patch_size[1] != 0:
513
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
514
+ if H % self.patch_size[0] != 0:
515
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
516
+
517
+ x = self.proj(x) # B C Wh Ww
518
+ if self.norm is not None:
519
+ Wh, Ww = x.size(2), x.size(3)
520
+ x = x.flatten(2).transpose(1, 2)
521
+ x = self.norm(x)
522
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
523
+
524
+ return x
525
+
526
+
527
+ @BACKBONES.register_module()
528
+ class SwinTransformer(nn.Module):
529
+ """Swin Transformer backbone.
530
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
531
+ https://arxiv.org/pdf/2103.14030
532
+ Args:
533
+ pretrain_img_size (int): Input image size for training the pretrained model,
534
+ used in the absolute position embedding. Default 224.
535
+ patch_size (int | tuple(int)): Patch size. Default: 4.
536
+ in_chans (int): Number of input image channels. Default: 3.
537
+ embed_dim (int): Number of linear projection output channels. Default: 96.
538
+ depths (tuple[int]): Depths of each Swin Transformer stage.
539
+ num_heads (tuple[int]): Number of attention head of each stage.
540
+ window_size (int): Window size. Default: 7.
541
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
542
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
543
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
544
+ drop_rate (float): Dropout rate.
545
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
546
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
547
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
548
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
549
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
550
+ out_indices (Sequence[int]): Output from which stages.
551
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
552
+ -1 means not freezing any parameters.
553
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
554
+ dilation (bool): If True, the output stride is 16x downsampling; otherwise 32x downsampling.
555
+ """
556
+
557
+ def __init__(
558
+ self,
559
+ pretrain_img_size=224,
560
+ patch_size=4,
561
+ in_chans=3,
562
+ embed_dim=96,
563
+ depths=[2, 2, 6, 2],
564
+ num_heads=[3, 6, 12, 24],
565
+ window_size=7,
566
+ mlp_ratio=4.0,
567
+ qkv_bias=True,
568
+ qk_scale=None,
569
+ drop_rate=0.0,
570
+ attn_drop_rate=0.0,
571
+ drop_path_rate=0.2,
572
+ norm_layer=nn.LayerNorm,
573
+ ape=False,
574
+ patch_norm=True,
575
+ out_indices=(0, 1, 2, 3),
576
+ frozen_stages=-1,
577
+ dilation=False,
578
+ use_checkpoint=False,
579
+ ):
580
+ super().__init__()
581
+
582
+ self.pretrain_img_size = pretrain_img_size
583
+ self.num_layers = len(depths)
584
+ self.embed_dim = embed_dim
585
+ self.ape = ape
586
+ self.patch_norm = patch_norm
587
+ self.out_indices = out_indices
588
+ self.frozen_stages = frozen_stages
589
+ self.dilation = dilation
590
+
591
+ if use_checkpoint:
592
+ print("use_checkpoint!!!!!!!!!!!!!!!!!!!!!!!!")
593
+
594
+ # split image into non-overlapping patches
595
+ self.patch_embed = PatchEmbed(
596
+ patch_size=patch_size,
597
+ in_chans=in_chans,
598
+ embed_dim=embed_dim,
599
+ norm_layer=norm_layer if self.patch_norm else None,
600
+ )
601
+
602
+ # absolute position embedding
603
+ if self.ape:
604
+ pretrain_img_size = to_2tuple(pretrain_img_size)
605
+ patch_size = to_2tuple(patch_size)
606
+ patches_resolution = [
607
+ pretrain_img_size[0] // patch_size[0],
608
+ pretrain_img_size[1] // patch_size[1],
609
+ ]
610
+
611
+ self.absolute_pos_embed = nn.Parameter(
612
+ torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
613
+ )
614
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
615
+
616
+ self.pos_drop = nn.Dropout(p=drop_rate)
617
+
618
+ # stochastic depth
619
+ dpr = [
620
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
621
+ ] # stochastic depth decay rule
622
+
623
+ # build layers
624
+ self.layers = nn.ModuleList()
625
+ # prepare downsample list
626
+ downsamplelist = [PatchMerging for i in range(self.num_layers)]
627
+ downsamplelist[-1] = None
628
+ num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
629
+ if self.dilation:
630
+ downsamplelist[-2] = None
631
+ num_features[-1] = int(embed_dim * 2 ** (self.num_layers - 1)) // 2
632
+ for i_layer in range(self.num_layers):
633
+ layer = BasicLayer(
634
+ # dim=int(embed_dim * 2 ** i_layer),
635
+ dim=num_features[i_layer],
636
+ depth=depths[i_layer],
637
+ num_heads=num_heads[i_layer],
638
+ window_size=window_size,
639
+ mlp_ratio=mlp_ratio,
640
+ qkv_bias=qkv_bias,
641
+ qk_scale=qk_scale,
642
+ drop=drop_rate,
643
+ attn_drop=attn_drop_rate,
644
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
645
+ norm_layer=norm_layer,
646
+ # downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
647
+ downsample=downsamplelist[i_layer],
648
+ use_checkpoint=use_checkpoint,
649
+ )
650
+ self.layers.append(layer)
651
+
652
+ # num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
653
+ self.num_features = num_features
654
+
655
+ # add a norm layer for each output
656
+ for i_layer in out_indices:
657
+ layer = norm_layer(num_features[i_layer])
658
+ layer_name = f"norm{i_layer}"
659
+ self.add_module(layer_name, layer)
660
+
661
+ self._freeze_stages()
662
+
663
+ def _freeze_stages(self):
664
+ if self.frozen_stages >= 0:
665
+ self.patch_embed.eval()
666
+ for param in self.patch_embed.parameters():
667
+ param.requires_grad = False
668
+
669
+ if self.frozen_stages >= 1 and self.ape:
670
+ self.absolute_pos_embed.requires_grad = False
671
+
672
+ if self.frozen_stages >= 2:
673
+ self.pos_drop.eval()
674
+ for i in range(0, self.frozen_stages - 1):
675
+ m = self.layers[i]
676
+ m.eval()
677
+ for param in m.parameters():
678
+ param.requires_grad = False
679
+
680
+ def forward_raw(self, x: torch.Tensor) -> List[torch.Tensor]:
681
+ """Forward function."""
682
+ x = self.patch_embed(x)
683
+
684
+ Wh, Ww = x.size(2), x.size(3)
685
+ if self.ape:
686
+ # interpolate the position embedding to the corresponding size
687
+ absolute_pos_embed = F.interpolate(
688
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
689
+ )
690
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
691
+ else:
692
+ x = x.flatten(2).transpose(1, 2)
693
+ x = self.pos_drop(x)
694
+
695
+ outs = []
696
+ for i in range(self.num_layers):
697
+ layer = self.layers[i]
698
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
699
+ # import ipdb; ipdb.set_trace()
700
+
701
+ if i in self.out_indices:
702
+ norm_layer = getattr(self, f"norm{i}")
703
+ x_out = norm_layer(x_out)
704
+
705
+ out = (
706
+ x_out.view(-1, H, W, self.num_features[i])
707
+ .permute(0, 3, 1, 2)
708
+ .contiguous()
709
+ )
710
+ outs.append(out)
711
+
712
+ return tuple(outs)
713
+
714
+ def forward(self, tensor_list: NestedTensor) -> Dict:
715
+ """Forward function.
716
+
717
+ Args:
718
+ tensor_list (NestedTensor): NestedTensor object containing tensors and masks.
719
+
720
+ Returns:
721
+ Dict: Dict containing output tensors. The structure is as follows.
722
+ - 0: NestedTensor from stage 0.
723
+ - 1: NestedTensor from stage 1.
724
+ - 2: NestedTensor from stage 2.
725
+ - 3: NestedTensor from stage 3.
726
+ """
727
+ x = tensor_list.tensors
728
+
729
+ x = self.patch_embed(x)
730
+
731
+ Wh, Ww = x.size(2), x.size(3)
732
+ if self.ape:
733
+ # interpolate the position embedding to the corresponding size
734
+ absolute_pos_embed = F.interpolate(
735
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
736
+ )
737
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
738
+ else:
739
+ x = x.flatten(2).transpose(1, 2)
740
+ x = self.pos_drop(x)
741
+
742
+ outs = []
743
+ for i in range(self.num_layers):
744
+ layer = self.layers[i]
745
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
746
+
747
+ if i in self.out_indices:
748
+ norm_layer = getattr(self, f"norm{i}")
749
+ x_out = norm_layer(x_out)
750
+
751
+ out = (
752
+ x_out.view(-1, H, W, self.num_features[i])
753
+ .permute(0, 3, 1, 2)
754
+ .contiguous()
755
+ )
756
+ outs.append(out)
757
+
758
+ # collect for nesttensors
759
+ outs_dict = {}
760
+ for idx, out_i in enumerate(outs):
761
+ m = tensor_list.mask
762
+ assert m is not None
763
+ mask = F.interpolate(m[None].float(), size=out_i.shape[-2:]).to(torch.bool)[
764
+ 0
765
+ ]
766
+ outs_dict[idx] = NestedTensor(out_i, mask)
767
+
768
+ return outs_dict
769
+
770
+ def train(self, mode=True):
771
+ """Convert the model into training mode while keep layers freezed."""
772
+ super(SwinTransformer, self).train(mode)
773
+ self._freeze_stages()
774
+
775
+
776
+ def build_swin_transformer(modelname, pretrain_img_size, **kw):
777
+ assert modelname in [
778
+ "swin_T_224_1k",
779
+ "swin_B_224_22k",
780
+ "swin_B_384_22k",
781
+ "swin_L_224_22k",
782
+ "swin_L_384_22k",
783
+ ]
784
+
785
+ model_para_dict = {
786
+ "swin_T_224_1k": dict(
787
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7
788
+ ),
789
+ "swin_B_224_22k": dict(
790
+ embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7
791
+ ),
792
+ "swin_B_384_22k": dict(
793
+ embed_dim=128,
794
+ depths=[2, 2, 18, 2],
795
+ num_heads=[4, 8, 16, 32],
796
+ window_size=12,
797
+ ),
798
+ "swin_L_224_22k": dict(
799
+ embed_dim=192,
800
+ depths=[2, 2, 18, 2],
801
+ num_heads=[6, 12, 24, 48],
802
+ window_size=7,
803
+ ),
804
+ "swin_L_384_22k": dict(
805
+ embed_dim=192,
806
+ depths=[2, 2, 18, 2],
807
+ num_heads=[6, 12, 24, 48],
808
+ window_size=12,
809
+ ),
810
+ }
811
+ kw_cgf = model_para_dict[modelname]
812
+ kw_cgf.update(kw)
813
+ model = SwinTransformer(pretrain_img_size=pretrain_img_size, **kw_cgf)
814
+ return model
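A minimal usage sketch for the backbone defined above (illustrative only; it assumes the `detect_tools` package and its `timm` dependency are importable and uses randomly initialized weights):

    import torch
    from detect_tools.upn.models.backbone.swin import build_swin_transformer

    # Swin-T configuration from model_para_dict above; out_indices selects which stages are returned.
    backbone = build_swin_transformer("swin_T_224_1k", pretrain_img_size=224, out_indices=(0, 1, 2, 3))
    backbone.eval()

    with torch.no_grad():
        feats = backbone.forward_raw(torch.randn(1, 3, 224, 224))

    # Expected shapes for a 224x224 input (strides 4/8/16/32):
    # (1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)
    for f in feats:
        print(f.shape)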
detect_tools/upn/models/backbone/wrapper.py ADDED
@@ -0,0 +1,297 @@
1
+ from typing import Dict, List, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from detect_tools.upn import BACKBONES, build_backbone, build_position_embedding
7
+ from detect_tools.upn.models.module import NestedTensor
8
+ from detect_tools.upn.models.utils import clean_state_dict
9
+
10
+
11
+ class FrozenBatchNorm2d(torch.nn.Module):
12
+ """
13
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
14
+
15
+ Copy-paste from torchvision.misc.ops with added eps before rqsrt,
16
+ without which any other models than torchvision.models.resnet[18,34,50,101]
17
+ produce nans.
18
+ """
19
+
20
+ def __init__(self, n):
21
+ super(FrozenBatchNorm2d, self).__init__()
22
+ self.register_buffer("weight", torch.ones(n))
23
+ self.register_buffer("bias", torch.zeros(n))
24
+ self.register_buffer("running_mean", torch.zeros(n))
25
+ self.register_buffer("running_var", torch.ones(n))
26
+
27
+ def _load_from_state_dict(
28
+ self,
29
+ state_dict,
30
+ prefix,
31
+ local_metadata,
32
+ strict,
33
+ missing_keys,
34
+ unexpected_keys,
35
+ error_msgs,
36
+ ):
37
+ num_batches_tracked_key = prefix + "num_batches_tracked"
38
+ if num_batches_tracked_key in state_dict:
39
+ del state_dict[num_batches_tracked_key]
40
+
41
+ super(FrozenBatchNorm2d, self)._load_from_state_dict(
42
+ state_dict,
43
+ prefix,
44
+ local_metadata,
45
+ strict,
46
+ missing_keys,
47
+ unexpected_keys,
48
+ error_msgs,
49
+ )
50
+
51
+ def forward(self, x):
52
+ # move reshapes to the beginning
53
+ # to make it fuser-friendly
54
+ w = self.weight.reshape(1, -1, 1, 1)
55
+ b = self.bias.reshape(1, -1, 1, 1)
56
+ rv = self.running_var.reshape(1, -1, 1, 1)
57
+ rm = self.running_mean.reshape(1, -1, 1, 1)
58
+ eps = 1e-5
59
+ scale = w * (rv + eps).rsqrt()
60
+ bias = b - rm * scale
61
+ return x * scale + bias
62
+
63
+
64
+ class Joiner(nn.Module):
65
+ """A wrapper for the backbone and the position embedding.
66
+
67
+ Args:
68
+ backbone (nn.Module): The backbone module to wrap.
69
+ position_embedding (nn.Module): The position embedding module.
70
+ """
71
+
72
+ def __init__(self, backbone: nn.Module, position_embedding: nn.Module) -> None:
73
+ super().__init__()
74
+ self.backbone = backbone
75
+ self.pos_embed = position_embedding
76
+
77
+ def forward(
78
+ self, tensor_list: NestedTensor
79
+ ) -> Union[List[NestedTensor], List[torch.Tensor]]:
80
+ """Forward function.
81
+
82
+ Args:
83
+ tensor_list (NestedTensor): NestedTensor wrapping the input tensor.
84
+
85
+ Returns:
86
+ List[NestedTensor]: A list of feature maps in NestedTensor format.
87
+ List[torch.Tensor]: A list of position encodings.
88
+ """
89
+
90
+ xs = self.backbone(tensor_list)
91
+ out: List[NestedTensor] = []
92
+ pos = []
93
+ for layer_idx, x in xs.items():
94
+ out.append(x)
95
+ # position encoding
96
+ pos.append(self.pos_embed(x).to(x.tensors.dtype))
97
+
98
+ return out, pos
99
+
100
+ def forward_pos_embed_only(self, x: NestedTensor) -> torch.Tensor:
101
+ """Forward function for position embedding only. This is used to generate additional layer
102
+
103
+ Args:
104
+ x (NestedTensor): NestedTensor wrapping the input tensor.
105
+
106
+ Returns:
107
+ torch.Tensor: The position encoding.
108
+ """
109
+ return self.pos_embed(x)
110
+
111
+
112
+ @BACKBONES.register_module()
113
+ class SwinWrapper(nn.Module):
114
+ """A wrapper for swin transformer.
115
+
116
+ Args:
117
+ backbone_cfg Union[Dict, str]: Config dict to build backbone. If given a str name, we
118
+ will call `get_swin_config` to get the config dict.
119
+ dilation (bool): Whether to use dilation in stage 4.
120
+ position_embedding_cfg (Dict): Config dict to build position embedding.
121
+ lr_backbone (float): Learning rate of the backbone.
122
+ return_interm_layers (List[int]): Which layers to return.
123
+ backbone_freeze_keywords (List[str]): List of keywords to freeze the backbone.
124
+ use_checkpoint (bool): Whether to use checkpoint. Default: False.
125
+ ckpt_path (str): Checkpoint path. Default: None.
126
+ use_pretrained_ckpt (bool): Whether to use pretrained checkpoint. Default: True.
127
+ """
128
+
129
+ def __init__(
130
+ self,
131
+ backbone_cfg: Union[Dict, str],
132
+ dilation: bool,
133
+ position_embedding_cfg: Dict,
134
+ lr_backbone: float,
135
+ return_interm_indices: List[int],
136
+ backbone_freeze_keywords: List[str],
137
+ use_checkpoint: bool = False,
138
+ backbone_ckpt_path: str = None,
139
+ ) -> None:
140
+ super(SwinWrapper, self).__init__()
141
+ pos_embedding = build_position_embedding(position_embedding_cfg)
142
+ train_backbone = lr_backbone > 0
143
+ if not train_backbone:
144
+ raise ValueError("Please set lr_backbone > 0")
145
+ assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
146
+
147
+ # build backbone
148
+ if isinstance(backbone_cfg, str):
149
+ assert (
150
+ backbone_cfg
151
+ # must be one of the supported Swin variants listed below
152
+ in [
153
+ "swin_T_224_1k",
154
+ "swin_B_224_22k",
155
+ "swin_B_384_22k",
156
+ "swin_L_224_22k",
157
+ "swin_L_384_22k",
158
+ ]
159
+ )
160
+ pretrain_img_size = int(backbone_cfg.split("_")[-2])
161
+ backbone_cfg = get_swin_config(
162
+ backbone_cfg,
163
+ pretrain_img_size,
164
+ out_indices=tuple(return_interm_indices),
165
+ dilation=dilation,
166
+ use_checkpoint=use_checkpoint,
167
+ )
168
+ backbone = build_backbone(backbone_cfg)
169
+
170
+ # freeze some layers
171
+ if backbone_freeze_keywords is not None:
172
+ for name, parameter in backbone.named_parameters():
173
+ for keyword in backbone_freeze_keywords:
174
+ if keyword in name:
175
+ parameter.requires_grad_(False)
176
+ break
177
+
178
+ # load checkpoint
179
+ if backbone_ckpt_path is not None:
180
+ print("Loading backbone checkpoint from {}".format(backbone_ckpt_path))
181
+ checkpoint = torch.load(backbone_ckpt_path, map_location="cpu")["model"]
182
+ from collections import OrderedDict
183
+
184
+ def key_select_function(keyname):
185
+ if "head" in keyname:
186
+ return False
187
+ if dilation and "layers.3" in keyname:
188
+ return False
189
+ return True
190
+
191
+ _tmp_st = OrderedDict(
192
+ {
193
+ k: v
194
+ for k, v in clean_state_dict(checkpoint).items()
195
+ if key_select_function(k)
196
+ }
197
+ )
198
+ _tmp_st_output = backbone.load_state_dict(_tmp_st, strict=False)
199
+ print(str(_tmp_st_output))
200
+
201
+ bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :]
202
+ assert len(bb_num_channels) == len(
203
+ return_interm_indices
204
+ ), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}"
205
+
206
+ model = Joiner(backbone, pos_embedding)
207
+ model.num_channels = bb_num_channels
208
+ self.num_channels = bb_num_channels
209
+ self.model = model
210
+
211
+ def forward(
212
+ self, tensor_list: NestedTensor
213
+ ) -> Union[List[NestedTensor], List[torch.Tensor]]:
214
+ """Forward function.
215
+
216
+ Args:
217
+ tensor_list (NestedTensor): NestedTensor wrapping the input tensor.
218
+
219
+ Returns:
220
+ List[NestedTensor]: A list of feature maps in NestedTensor format.
221
+ List[torch.Tensor]: A list of position encodings.
222
+ """
223
+
224
+ return self.model(tensor_list)
225
+
226
+ def forward_pos_embed_only(self, tensor_list: NestedTensor) -> torch.Tensor:
227
+ """Forward function to get position embedding only.
228
+
229
+ Args:
230
+ tensor_list (NestedTensor): NestedTensor wrapping the input tensor.
231
+
232
+ Returns:
233
+ torch.Tensor: Position embedding.
234
+ """
235
+ return self.model.forward_pos_embed_only(tensor_list)
236
+
237
+
238
+ def get_swin_config(modelname: str, pretrain_img_size: Tuple[int, int], **kw):
239
+ """Get swin config dict.
240
+
241
+ Args:
242
+ modelname (str): Name of the model.
243
+ pretrain_img_size (Tuple[int, int]): Image size of the pretrain model.
244
+ kw (Dict): Other key word arguments.
245
+
246
+ Returns:
247
+ Dict: Config dict.
248
+ The dict can be passed to build_backbone to construct a SwinTransformer.
249
+ """
250
+ assert modelname in [
251
+ "swin_T_224_1k",
252
+ "swin_B_224_22k",
253
+ "swin_B_384_22k",
254
+ "swin_L_224_22k",
255
+ "swin_L_384_22k",
256
+ ]
257
+ model_para_dict = {
258
+ "swin_T_224_1k": dict(
259
+ type="SwinTransformer",
260
+ embed_dim=96,
261
+ depths=[2, 2, 6, 2],
262
+ num_heads=[3, 6, 12, 24],
263
+ window_size=7,
264
+ ),
265
+ "swin_B_224_22k": dict(
266
+ type="SwinTransformer",
267
+ embed_dim=128,
268
+ depths=[2, 2, 18, 2],
269
+ num_heads=[4, 8, 16, 32],
270
+ window_size=7,
271
+ ),
272
+ "swin_B_384_22k": dict(
273
+ type="SwinTransformer",
274
+ embed_dim=128,
275
+ depths=[2, 2, 18, 2],
276
+ num_heads=[4, 8, 16, 32],
277
+ window_size=12,
278
+ ),
279
+ "swin_L_224_22k": dict(
280
+ type="SwinTransformer",
281
+ embed_dim=192,
282
+ depths=[2, 2, 18, 2],
283
+ num_heads=[6, 12, 24, 48],
284
+ window_size=7,
285
+ ),
286
+ "swin_L_384_22k": dict(
287
+ type="SwinTransformer",
288
+ embed_dim=192,
289
+ depths=[2, 2, 18, 2],
290
+ num_heads=[6, 12, 24, 48],
291
+ window_size=12,
292
+ ),
293
+ }
294
+ kw_cgf = model_para_dict[modelname]
295
+ kw_cgf.update(kw)
296
+ kw_cgf.update(dict(pretrain_img_size=pretrain_img_size))
297
+ return kw_cgf
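A sketch of how the wrapper above might be constructed. The position_embedding_cfg contents are an assumption: they must match whatever build_position_embedding in this repository expects, and the type name used here is hypothetical:

    from detect_tools.upn.models.backbone.wrapper import SwinWrapper, get_swin_config

    # get_swin_config only assembles a registry config dict; it does not load any weights.
    cfg = get_swin_config("swin_L_384_22k", pretrain_img_size=384, out_indices=(1, 2, 3), dilation=False)
    # -> {"type": "SwinTransformer", "embed_dim": 192, "depths": [2, 2, 18, 2], ...}

    wrapper = SwinWrapper(
        backbone_cfg="swin_L_384_22k",        # a string triggers get_swin_config internally
        dilation=False,
        position_embedding_cfg=dict(type="PositionEmbeddingSine", num_pos_feats=128),  # assumed name
        lr_backbone=1e-5,                      # must be > 0, see the check in __init__
        return_interm_indices=[1, 2, 3],
        backbone_freeze_keywords=None,
        backbone_ckpt_path=None,               # optionally a pretrained Swin checkpoint with a "model" key
    )
    # wrapper(tensor_list) returns ([NestedTensor per level], [position encodings per level])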
detect_tools/upn/models/decoder/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .upn_decoder import UPNDecoder, DeformableTransformerDecoderLayer
2
+
3
+ __all__ = ["UPNDecoder", "DeformableTransformerDecoderLayer"]
detect_tools/upn/models/decoder/upn_decoder.py ADDED
@@ -0,0 +1,378 @@
1
+ from typing import Dict
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from detect_tools.upn import DECODERS, build_decoder
7
+ from detect_tools.upn.models.module import MLP
8
+ from detect_tools.upn.models.utils import (gen_sineembed_for_position,
9
+ get_activation_fn, get_clones,
10
+ inverse_sigmoid)
11
+ from detect_tools.upn.ops.modules import MSDeformAttn
12
+
13
+
14
+ @DECODERS.register_module()
15
+ class DeformableTransformerDecoderLayer(nn.Module):
16
+ """Deformable Transformer Decoder Layer. This is a modified version in Grounding DINO.
17
+ After the query is attented to the image feature, it is further attented to the text feature.
18
+ The execute order is: self_attn -> cross_attn to text -> cross_attn to image -> ffn
19
+ Args:
20
+ d_model (int): The dimension of keys/values/queries in :class:`MultiheadAttention`.
21
+ d_ffn (int): The dimension of the feedforward network model.
22
+ dropout (float): Probability of an element to be zeroed.
23
+ activation (str): Activation function in the feedforward network.
24
+ 'relu' and 'gelu' are supported.
25
+ n_levels (int): The number of levels in Multi-Scale Deformable Attention.
26
+ n_heads (int): Parallel attention heads.
27
+ n_points (int): Number of sampling points in Multi-Scale Deformable Attention.
28
+ ffn_extra_layernorm (bool): If True, add an extra layernorm after ffn.
29
+ """
30
+
31
+ def __init__(
32
+ self,
33
+ d_model: int = 256,
34
+ d_ffn: int = 1024,
35
+ dropout: float = 0.1,
36
+ activation: str = "relu",
37
+ n_levels: int = 4,
38
+ n_heads: int = 8,
39
+ n_points: int = 4,
40
+ ffn_extra_layernorm: bool = False,
41
+ ) -> None:
42
+ super().__init__()
43
+
44
+ # cross attention for visual features
45
+ self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
46
+ self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
47
+ self.norm1 = nn.LayerNorm(d_model)
48
+
49
+ # self attention for query
50
+ self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
51
+ self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
52
+ self.norm2 = nn.LayerNorm(d_model)
53
+
54
+ # ffn
55
+ self.linear1 = nn.Linear(d_model, d_ffn)
56
+ self.activation = get_activation_fn(activation, d_model=d_ffn, batch_dim=1)
57
+ self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
58
+ self.linear2 = nn.Linear(d_ffn, d_model)
59
+ self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
60
+ self.norm3 = nn.LayerNorm(d_model)
61
+ if ffn_extra_layernorm:
62
+ raise NotImplementedError("ffn_extra_layernorm not implemented")
63
+ self.norm_ext = nn.LayerNorm(d_ffn)
64
+ else:
65
+ self.norm_ext = None
66
+
67
+ self.key_aware_proj = None
68
+
69
+ def rm_self_attn_modules(self):
70
+ self.self_attn = None
71
+ self.dropout2 = None
72
+ self.norm2 = None
73
+
74
+ @staticmethod
75
+ def with_pos_embed(tensor, pos):
76
+ return tensor if pos is None else tensor + pos
77
+
78
+ def forward_ffn(self, tgt):
79
+ tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
80
+
81
+ tgt = tgt + self.dropout4(tgt2)
82
+ tgt = self.norm3(tgt)
83
+ return tgt
84
+
85
+ def forward(
86
+ self,
87
+ tgt: torch.Tensor,
88
+ tgt_query_pos: torch.Tensor = None,
89
+ tgt_reference_points: torch.Tensor = None,
90
+ memory: torch.Tensor = None,
91
+ memory_key_padding_mask: torch.Tensor = None,
92
+ memory_level_start_index: torch.Tensor = None,
93
+ memory_spatial_shapes: torch.Tensor = None,
94
+ self_attn_mask: torch.Tensor = None,
95
+ cross_attn_mask: torch.Tensor = None,
96
+ ) -> torch.Tensor:
97
+ """Forward function
98
+
99
+ Args:
100
+ tgt (torch.Tensor): Input target in shape (B, T, C)
101
+ tgt_query_pos (torch.Tensor): Positional encoding of the query.
102
+ tgt_query_sine_embed (torch.Tensor): Sine positional encoding of the query. Unused.
103
+ tgt_key_padding_mask (torch.Tensor): Mask for target feature in shape (B, T).
104
+ tgt_reference_points (torch.Tensor): Reference points for the query in shape (B, T, 4).
105
+ memory_text (torch.Tensor): Input text embeddings in shape (B, num_token, C).
106
+ text_attention_mask (torch.Tensor): Attention mask for text embeddings in shape
107
+ (B, num_token).
108
+ memory (torch.Tensor): Input image feature in shape (B, HW, C)
109
+ memory_key_padding_mask (torch.Tensor): Mask for image feature in shape (B, HW)
110
+ memory_level_start_index (torch.Tensor): Starting index of each level in memory.
111
+ memory_spatial_shapes (torch.Tensor): Spatial shape of each level in memory.
112
+ memory_pos (torch.Tensor): Positional encoding of memory. Unused.
113
+ self_attn_mask (torch.Tensor): Mask used for self-attention.
114
+ cross_attn_mask (torch.Tensor): Mask used for cross-attention.
115
+
116
+ Returns:
117
+ torch.Tensor: Output tensor in shape (B, T, C)
118
+ """
119
+ assert cross_attn_mask is None
120
+
121
+ # self attention
122
+ if self.self_attn is not None:
123
+ q = k = self.with_pos_embed(tgt, tgt_query_pos)
124
+ tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
125
+ tgt = tgt + self.dropout2(tgt2)
126
+ tgt = self.norm2(tgt)
127
+
128
+ # attend to image features
129
+ tgt2 = self.cross_attn(
130
+ self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
131
+ tgt_reference_points.transpose(0, 1).contiguous(),
132
+ memory.transpose(0, 1),
133
+ memory_spatial_shapes,
134
+ memory_level_start_index,
135
+ memory_key_padding_mask,
136
+ ).transpose(0, 1)
137
+ tgt = tgt + self.dropout1(tgt2)
138
+ tgt = self.norm1(tgt)
139
+ # ffn
140
+ tgt = self.forward_ffn(tgt)
141
+
142
+ return tgt
143
+
144
+
145
+ @DECODERS.register_module()
146
+ class UPNDecoder(nn.Module):
147
+ """Decoder used in UPN. Each layer is a DeformableTransformerDecoderLayer. The query
148
+ will be abled to attend the image feature and text feature. The execute order is:
149
+ self_attn -> cross_attn to image -> ffn
150
+
151
+ Args:
152
+ decoder_layer_cfg (Dict): Config for the DeformableTransformerDecoderLayer.
153
+ num_layers (int): number of layers
154
+ norm (nn.Module, optional): normalization layer. Defaults to None.
155
+ return_intermediate (bool, optional): whether return intermediate results.
156
+ Defaults to False.
157
+ d_model (int, optional): dimension of the model. Defaults to 256.
158
+ query_dim (int, optional): dimension of the query. Defaults to 4.
159
+ modulate_hw_attn (bool, optional): whether modulate the attention weights
160
+ by the height and width of the image feature. Defaults to False.
161
+ num_feature_levels (int, optional): number of feature levels. Defaults to 1.
162
+ deformable_decoder (bool, optional): whether use deformable decoder. Defaults to False.
163
+ decoder_query_perturber ([type], optional): [description]. Defaults to None.
164
+ dec_layer_number ([type], optional): [description]. Defaults to None.
165
+ rm_dec_query_scale (bool, optional): [description]. Defaults to False.
166
+ dec_layer_share (bool, optional): [description]. Defaults to False.
167
+ dec_layer_dropout_prob ([type], optional): [description]. Defaults to None.
168
+ """
169
+
170
+ def __init__(
171
+ self,
172
+ decoder_layer_cfg: Dict,
173
+ num_layers: int,
174
+ norm: str = "layernorm",
175
+ return_intermediate: bool = True,
176
+ d_model: int = 256,
177
+ query_dim: int = 4,
178
+ modulate_hw_attn: bool = False,
179
+ num_feature_levels: int = 1,
180
+ deformable_decoder: bool = True,
181
+ decoder_query_perturber=None,
182
+ dec_layer_number=None,
183
+ rm_dec_query_scale: bool = True,
184
+ dec_layer_share: bool = False,
185
+ dec_layer_dropout_prob=None,
186
+ use_detached_boxes_dec_out: bool = False,
187
+ ):
188
+ super().__init__()
189
+
190
+ decoder_layer = build_decoder(decoder_layer_cfg)
191
+ if num_layers > 0:
192
+ self.layers = get_clones(
193
+ decoder_layer, num_layers, layer_share=dec_layer_share
194
+ )
195
+ else:
196
+ self.layers = []
197
+ self.num_layers = num_layers
198
+ if norm == "layernorm":
199
+ self.norm = nn.LayerNorm(d_model)
200
+ self.return_intermediate = return_intermediate
201
+ self.query_dim = query_dim
202
+ assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim)
203
+ self.num_feature_levels = num_feature_levels
204
+ self.use_detached_boxes_dec_out = use_detached_boxes_dec_out
205
+
206
+ self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
207
+ self.ref_point_head_point = MLP(
208
+ d_model, d_model, d_model, 2
209
+ ) # for point reference only
210
+ if not deformable_decoder:
211
+ self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2)
212
+ else:
213
+ self.query_pos_sine_scale = None
214
+
215
+ if rm_dec_query_scale:
216
+ self.query_scale = None
217
+ else:
218
+ raise NotImplementedError
219
+ self.query_scale = MLP(d_model, d_model, d_model, 2)
220
+ self.bbox_embed = None
221
+ self.class_embed = None
222
+
223
+ self.d_model = d_model
224
+ self.modulate_hw_attn = modulate_hw_attn
225
+ self.deformable_decoder = deformable_decoder
226
+
227
+ if not deformable_decoder and modulate_hw_attn:
228
+ self.ref_anchor_head = MLP(d_model, d_model, 2, 2)
229
+ else:
230
+ self.ref_anchor_head = None
231
+
232
+ self.decoder_query_perturber = decoder_query_perturber
233
+ self.box_pred_damping = None
234
+
235
+ self.dec_layer_number = dec_layer_number
236
+ if dec_layer_number is not None:
237
+ assert isinstance(dec_layer_number, list)
238
+ assert len(dec_layer_number) == num_layers
239
+
240
+ self.dec_layer_dropout_prob = dec_layer_dropout_prob
241
+ if dec_layer_dropout_prob is not None:
242
+ assert isinstance(dec_layer_dropout_prob, list)
243
+ assert len(dec_layer_dropout_prob) == num_layers
244
+ for i in dec_layer_dropout_prob:
245
+ assert 0.0 <= i <= 1.0
246
+
247
+ self.rm_detach = None
248
+
249
+ def forward(
250
+ self,
251
+ tgt: torch.Tensor,
252
+ memory: torch.Tensor,
253
+ tgt_mask: torch.Tensor = None,
254
+ memory_mask: torch.Tensor = None,
255
+ tgt_key_padding_mask: torch.Tensor = None,
256
+ memory_key_padding_mask: torch.Tensor = None,
257
+ pos: torch.Tensor = None,
258
+ refpoints_unsigmoid: torch.Tensor = None,
259
+ level_start_index: torch.Tensor = None,
260
+ spatial_shapes: torch.Tensor = None,
261
+ valid_ratios: torch.Tensor = None,
262
+ memory_ref_image: torch.Tensor = None,
263
+ refImg_padding_mask: torch.Tensor = None,
264
+ memory_visual_prompt: torch.Tensor = None,
265
+ ):
266
+ """Forward function.
267
+
268
+ Args:
269
+ tgt (torch.Tensor): target feature, [bs, num_queries, d_model]
270
+ memory (torch.Tensor): Image feature, [bs, hw, d_model]
271
+ tgt_mask (torch.Tensor, optional): target mask for attention. Defaults to None.
272
+ memory_mask (torch.Tensor, optional): image mask for attention. Defaults to None.
273
+ tgt_key_padding_mask (torch.Tensor, optional): target mask for padding. Defaults to None.
274
+ memory_key_padding_mask (torch.Tensor, optional): image mask for padding. Defaults to None.
275
+ pos (torch.Tensor, optional): query position embedding
276
+ refpoints_unsigmoid (torch.Tensor, optional): reference points. Defaults to None.
277
+ level_start_index (torch.Tensor, optional): start index of each level. Defaults to None.
278
+ spatial_shapes (torch.Tensor, optional): spatial shape of each level. Defaults to None.
279
+ valid_ratios (torch.Tensor, optional): valid ratio of each level. Defaults to None.
280
+ memory_ref_image (torch.Tensor, optional): reference image feature, [bs, num_ref, d_model]. Defaults to None.
281
+ refImg_padding_mask (torch.Tensor, optional): padding mask for attention. Defaults to None.
282
+ """
283
+ output = tgt
284
+
285
+ intermediate = []
286
+ reference_points = refpoints_unsigmoid.sigmoid()
287
+ ref_points = [reference_points]
288
+
289
+ for layer_id, layer in enumerate(self.layers):
290
+
291
+ if reference_points.shape[-1] == 4:
292
+ reference_points_input = (
293
+ reference_points[:, :, None]
294
+ * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
295
+ ) # nq, bs, nlevel, 4
296
+ else:
297
+ assert reference_points.shape[-1] == 2
298
+ reference_points_input = (
299
+ reference_points[:, :, None] * valid_ratios[None, :]
300
+ )
301
+ query_sine_embed = gen_sineembed_for_position(
302
+ reference_points_input[:, :, 0, :]
303
+ ) # nq, bs, 256*2
304
+
305
+ # conditional query
306
+ if query_sine_embed.shape[-1] == 512:
307
+ raw_query_pos = (
308
+ self.ref_point_head(query_sine_embed)
309
+ + self.ref_point_head_point(
310
+ torch.zeros_like(query_sine_embed)[:, :, :256]
311
+ )
312
+ * 0.0
313
+ )
314
+ else:
315
+ raw_query_pos = (
316
+ self.ref_point_head_point(query_sine_embed)
317
+ + self.ref_point_head(
318
+ torch.zeros(
319
+ query_sine_embed.shape[0],
320
+ query_sine_embed.shape[1],
321
+ 512,
322
+ device=query_sine_embed.device,
323
+ )
324
+ )
325
+ * 0.0
326
+ )
327
+ pos_scale = self.query_scale(output) if self.query_scale is not None else 1
328
+ query_pos = pos_scale * raw_query_pos
329
+
330
+ # main process
331
+ output = layer(
332
+ tgt=output,
333
+ tgt_query_pos=query_pos,
334
+ tgt_reference_points=reference_points_input,
335
+ memory=memory,
336
+ memory_key_padding_mask=memory_key_padding_mask,
337
+ memory_level_start_index=level_start_index,
338
+ memory_spatial_shapes=spatial_shapes,
339
+ self_attn_mask=tgt_mask,
340
+ cross_attn_mask=memory_mask,
341
+ )
342
+ if output.isnan().any() | output.isinf().any():
343
+ print(f"output layer_id {layer_id} is nan")
344
+ try:
345
+ num_nan = output.isnan().sum().item()
346
+ num_inf = output.isinf().sum().item()
347
+ print(f"num_nan {num_nan}, num_inf {num_inf}")
348
+ except Exception as e:
349
+ print(e)
350
+
351
+ # iter update
352
+ if self.bbox_embed is not None:
353
+
354
+ reference_before_sigmoid = inverse_sigmoid(reference_points)
355
+ delta_unsig = self.bbox_embed[layer_id](output)
356
+ outputs_unsig = delta_unsig + reference_before_sigmoid
357
+ new_reference_points = outputs_unsig.sigmoid()
358
+
359
+ if self.rm_detach and "dec" in self.rm_detach:
360
+ reference_points = new_reference_points
361
+ else:
362
+ reference_points = new_reference_points.detach()
363
+
364
+ if self.use_detached_boxes_dec_out:
365
+ ref_points.append(reference_points)
366
+ else:
367
+ ref_points.append(new_reference_points)
368
+
369
+ if self.return_intermediate:
370
+ intermediate.append(self.norm(output))
371
+
372
+ if self.return_intermediate:
373
+ return [
374
+ [itm_out.transpose(0, 1) for itm_out in intermediate],
375
+ [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points],
376
+ ]
377
+ else:
378
+ return self.norm(output).transpose(0, 1)
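The per-layer box update in the loop above follows the usual deformable-DETR refinement: reference boxes live in sigmoid space, each layer predicts an unsigmoided delta through bbox_embed, and the result is squashed back. A self-contained sketch of that single step (shapes are illustrative; bbox_embed is attached to the decoder by the parent model):

    import torch

    def inverse_sigmoid(x, eps=1e-5):
        x = x.clamp(min=0, max=1)
        return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))

    reference_points = torch.rand(900, 2, 4)      # (num_queries, batch, cxcywh), all in [0, 1]
    delta_unsig = torch.randn(900, 2, 4) * 0.01   # stands in for self.bbox_embed[layer_id](output)

    new_reference_points = (inverse_sigmoid(reference_points) + delta_unsig).sigmoid()
    # detach() keeps gradients from flowing through earlier layers' box estimates
    reference_points = new_reference_points.detach()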
detect_tools/upn/models/encoder/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .upn_encoder import DeformableTransformerEncoderLayer, UPNEncoder
2
+
3
+ __all__ = ["UPNEncoder", "DeformableTransformerEncoderLayer"]
detect_tools/upn/models/encoder/upn_encoder.py ADDED
@@ -0,0 +1,288 @@
1
+ from typing import Dict
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.utils.checkpoint as checkpoint
6
+
7
+ from detect_tools.upn import ENCODERS, build_encoder
8
+ from detect_tools.upn.models.utils import get_activation_fn, get_clones
9
+ from detect_tools.upn.ops.modules import MSDeformAttn
10
+
11
+
12
+ @ENCODERS.register_module()
13
+ class DeformableTransformerEncoderLayer(nn.Module):
14
+ """Deformable Transformer Encoder Layer.
15
+
16
+ Args:
17
+ d_model (int): The dimension of keys/values/queries in
18
+ :class:`MultiheadAttention`.
19
+ d_ffn (int): The dimension of the feedforward network model.
20
+ dropout (float): Probability of an element to be zeroed.
21
+ activation (str): Activation function in the feedforward network.
22
+ 'relu' and 'gelu' are supported.
23
+ n_levels (int): The number of levels in Multi-Scale Deformable Attention.
24
+ n_heads (int): Parallel attention heads.
25
+ n_points (int): Number of sampling points in Multi-Scale Deformable Attention.
26
+ add_channel_attention (bool): If True, add channel attention.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ d_model: int = 256,
32
+ d_ffn: int = 1024,
33
+ dropout: float = 0.1,
34
+ activation: str = "relu",
35
+ n_levels: int = 4,
36
+ n_heads: int = 8,
37
+ n_points: int = 4,
38
+ add_channel_attention: bool = False,
39
+ ) -> None:
40
+ super().__init__()
41
+
42
+ # self attention
43
+ self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
44
+ self.dropout1 = nn.Dropout(dropout)
45
+ self.norm1 = nn.LayerNorm(d_model)
46
+
47
+ # ffn
48
+ self.linear1 = nn.Linear(d_model, d_ffn)
49
+ self.activation = get_activation_fn(activation, d_model=d_ffn)
50
+ self.dropout2 = nn.Dropout(dropout)
51
+ self.linear2 = nn.Linear(d_ffn, d_model)
52
+ self.dropout3 = nn.Dropout(dropout)
53
+ self.norm2 = nn.LayerNorm(d_model)
54
+
55
+ # channel attention
56
+ self.add_channel_attention = add_channel_attention
57
+ if add_channel_attention:
58
+ self.activ_channel = get_activation_fn("dyrelu", d_model=d_model)
59
+ self.norm_channel = nn.LayerNorm(d_model)
60
+
61
+ @staticmethod
62
+ def with_pos_embed(tensor, pos):
63
+ return tensor if pos is None else tensor + pos
64
+
65
+ def forward_ffn(self, src: torch.Tensor) -> torch.Tensor:
66
+ src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
67
+ src = src + self.dropout3(src2)
68
+ src = self.norm2(src)
69
+ return src
70
+
71
+ def forward(
72
+ self,
73
+ src: torch.Tensor,
74
+ pos: torch.Tensor,
75
+ reference_points: torch.Tensor,
76
+ spatial_shapes: torch.Tensor,
77
+ level_start_index: torch.Tensor,
78
+ key_padding_mask: torch.Tensor = None,
79
+ ) -> torch.Tensor:
80
+ """Forward function for `DeformableTransformerEncoderLayer`.
81
+
82
+ Args:
83
+ src (torch.Tensor): The input sequence of shape (S, N, E).
84
+ pos (torch.Tensor): The position embedding of shape (S, N, E).
85
+ reference_points (torch.Tensor): The reference points of shape (N, L, 2).
86
+ spatial_shapes (torch.Tensor): The spatial shapes of feature levels.
87
+ level_start_index (torch.Tensor): The start index of each level.
88
+ key_padding_mask (torch.Tensor): The mask for keys with shape (N, S).
89
+ """
90
+ # self attention
91
+ # import ipdb; ipdb.set_trace()
92
+ src2 = self.self_attn(
93
+ self.with_pos_embed(src, pos),
94
+ reference_points,
95
+ src,
96
+ spatial_shapes,
97
+ level_start_index,
98
+ key_padding_mask,
99
+ )
100
+ src = src + self.dropout1(src2)
101
+ src = self.norm1(src)
102
+
103
+ # ffn
104
+ src = self.forward_ffn(src)
105
+
106
+ # channel attn
107
+ if self.add_channel_attention:
108
+ src = self.norm_channel(src + self.activ_channel(src))
109
+
110
+ return src
111
+
112
+
113
+ @ENCODERS.register_module()
114
+ class UPNEncoder(nn.Module):
115
+ """Implementation of UPN Encoder.
116
+
117
+ Args:
118
+ num_layers (int): The number of layers in the TransformerEncoder.
119
+ d_model (int, optional): The dimension of the input feature. Defaults to 256.
120
+ encoder_layer_cfg (Dict): Config for the DeformableEncoderLayer.
121
+ use_checkpoint (bool, optional): Whether to use checkpoint in the fusion layer for
122
+ memory saving. Defaults to False.
123
+ use_transformer_ckpt (bool, optional): Whether to use checkpointing for the deformable encoder layers.
124
+ enc_layer_share (bool, optional): Whether to share a single layer module across all encoder
125
+ layers. Defaults to False.
126
+ """
127
+
128
+ def __init__(
129
+ self,
130
+ num_layers: int,
131
+ d_model: int = 256,
132
+ encoder_layer_cfg: Dict = None,
133
+ use_checkpoint: bool = True,
134
+ use_transformer_ckpt: bool = True,
135
+ enc_layer_share: bool = False,
136
+ multi_level_encoder_fusion: str = None,
137
+ ):
138
+ super().__init__()
139
+ # prepare layers
140
+ self.layers = []
141
+ self.refImg_layers = []
142
+ self.fusion_layers = []
143
+ encoder_layer = build_encoder(encoder_layer_cfg)
144
+
145
+ self.multi_level_encoder_fusion = multi_level_encoder_fusion
146
+ self._initilize_memory_fusion_layers(
147
+ multi_level_encoder_fusion, num_layers, d_model
148
+ )
149
+
150
+ if num_layers > 0:
151
+ self.layers = get_clones(
152
+ encoder_layer, num_layers, layer_share=enc_layer_share
153
+ )
154
+ else:
155
+ self.layers = []
156
+ del encoder_layer
157
+
158
+ self.query_scale = None
159
+ self.num_layers = num_layers
160
+ self.d_model = d_model
161
+
162
+ self.use_checkpoint = use_checkpoint
163
+ self.use_transformer_ckpt = use_transformer_ckpt
164
+
165
+ def _initilize_memory_fusion_layers(self, fusion_type, num_layers, d_model):
166
+ if fusion_type is None:
167
+ self.memory_fusion_layer = None
168
+ return
169
+
170
+ assert fusion_type in ["dense_net_fusion", "stable_dense_fusion"]
171
+ if fusion_type == "stable_dense_fusion":
172
+ self.memory_fusion_layer = nn.Sequential(
173
+ nn.Linear(d_model * (num_layers + 1), d_model),
174
+ nn.LayerNorm(d_model),
175
+ )
176
+ nn.init.constant_(self.memory_fusion_layer[0].bias, 0)
177
+ elif fusion_type == "dense_net_fusion":
178
+ self.memory_fusion_layer = nn.ModuleList()
179
+ for i in range(num_layers):
180
+ self.memory_fusion_layer.append(
181
+ nn.Sequential(
182
+ nn.Linear(
183
+ d_model * (i + 2), d_model
184
+ ), # from second encoder layer, 512 -> 256 / 3rd: 768 -> 256
185
+ nn.LayerNorm(d_model),
186
+ )
187
+ )
188
+ for layer in self.memory_fusion_layer:
189
+ nn.init.constant_(layer[0].bias, 0)
190
+ else:
191
+ raise NotImplementedError
192
+
193
+ @staticmethod
194
+ def get_reference_points(spatial_shapes, valid_ratios, device):
195
+ reference_points_list = []
196
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
197
+
198
+ ref_y, ref_x = torch.meshgrid(
199
+ torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
200
+ torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
201
+ )
202
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
203
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
204
+ ref = torch.stack((ref_x, ref_y), -1)
205
+ reference_points_list.append(ref)
206
+ reference_points = torch.cat(reference_points_list, 1)
207
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
208
+ return reference_points
209
+
210
+ def forward(
211
+ self,
212
+ src: torch.Tensor,
213
+ pos: torch.Tensor,
214
+ spatial_shapes: torch.Tensor,
215
+ level_start_index: torch.Tensor,
216
+ valid_ratios: torch.Tensor,
217
+ key_padding_mask: torch.Tensor = None,
218
+ ):
219
+ """Forward function
220
+
221
+ Args:
222
+ src (torch.Tensor): Flattened Image features in shape [bs, sum(hi*wi), 256]
223
+ pos (torch.Tensor): Position embedding for image feature in shape [bs, sum(hi*wi), 256]
224
+ spatial_shapes (torch.Tensor): Spatial shape of each level in shape [num_level, 2]
225
+ level_start_index (torch.Tensor): Start index of each level in shape [num_level]
226
+ valid_ratios (torch.Tensor): Valid ratio of each level in shape [bs, num_level, 2]
227
+ key_padding_mask (torch.Tensor): Padding mask for image feature in shape [bs, sum(hi*wi)]
228
+ Returns:
+ torch.Tensor: Encoded image feature in shape [bs, sum(hi*wi), 256]
239
+ """
240
+
241
+ output = src
242
+ # preparation and reshape
243
+ if self.num_layers > 0:
244
+ reference_points = self.get_reference_points(
245
+ spatial_shapes, valid_ratios, device=src.device
246
+ )
247
+
248
+ # multi-level dense fusion
249
+ output_list = [output]
250
+ # main process
251
+ for layer_id, layer in enumerate(self.layers):
252
+ # main process
253
+ if self.use_transformer_ckpt:
254
+ output = checkpoint.checkpoint(
255
+ layer,
256
+ output,
257
+ pos,
258
+ reference_points,
259
+ spatial_shapes,
260
+ level_start_index,
261
+ key_padding_mask,
262
+ )
263
+ else:
264
+ output = layer(
265
+ src=output,
266
+ pos=pos,
267
+ reference_points=reference_points,
268
+ spatial_shapes=spatial_shapes,
269
+ level_start_index=level_start_index,
270
+ key_padding_mask=key_padding_mask,
271
+ )
272
+
273
+ output_list.append(output)
274
+ if (
275
+ self.multi_level_encoder_fusion is not None
276
+ and self.multi_level_encoder_fusion == "dense_net_fusion"
277
+ ):
278
+ output = self.memory_fusion_layer[layer_id](
279
+ torch.cat(output_list, dim=-1)
280
+ )
281
+
282
+ if (
283
+ self.multi_level_encoder_fusion is not None
284
+ and self.multi_level_encoder_fusion == "stable_dense_fusion"
285
+ ):
286
+ output = self.memory_fusion_layer(torch.cat(output_list, dim=-1))
287
+
288
+ return output
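A minimal editorial sketch (not part of this commit) of what the two fusion modes compute dimensionally, assuming d_model=256 and num_layers=3; in the real encoder each fused output is fed back into the next layer before being appended to output_list:

import torch
import torch.nn as nn

d_model, num_layers, bs, hw = 256, 3, 2, 100
# src plus one output per encoder layer, as collected in output_list above
outputs = [torch.randn(bs, hw, d_model) for _ in range(num_layers + 1)]

# "stable_dense_fusion": a single projection over all outputs, applied once at the end
stable = nn.Sequential(nn.Linear(d_model * (num_layers + 1), d_model), nn.LayerNorm(d_model))
fused_once = stable(torch.cat(outputs, dim=-1))                 # [2, 100, 256]

# "dense_net_fusion": after layer i, project the concatenation of the first i + 2 outputs
dense = nn.ModuleList(
    nn.Sequential(nn.Linear(d_model * (i + 2), d_model), nn.LayerNorm(d_model))
    for i in range(num_layers)
)
fused = outputs[0]
for i in range(num_layers):
    fused = dense[i](torch.cat(outputs[: i + 2], dim=-1))       # 512 -> 256, 768 -> 256, 1024 -> 256
print(fused_once.shape, fused.shape)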
detect_tools/upn/models/module/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .contrastive import ContrastiveAssign
2
+ from .mlp import MLP
3
+ from .nested_tensor import NestedTensor, nested_tensor_from_tensor_list
4
+
5
+ __all__ = ["MLP", "NestedTensor", "nested_tensor_from_tensor_list", "ContrastiveAssign"]
detect_tools/upn/models/module/contrastive.py ADDED
@@ -0,0 +1,29 @@
1
+ from typing import Dict
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class ContrastiveAssign(nn.Module):
9
+
10
+ def __init__(
11
+ self,
12
+ cal_bias: nn.Module = None,
13
+ ) -> None:
14
+ """Language-Image Contrastive Assignment used to calculate the similarity between
15
+ the text and the image.
16
+
17
+ Args:
18
+ cal_bias (nn.Module, optional): The bias used to calculate the similarity.
19
+ Defaults to None.
20
21
+ """
22
+ super().__init__()
23
+ self.cal_bias = cal_bias
24
+
25
+ def forward(self, x: torch.Tensor, ref_dict: Dict):
26
+
27
+ y = ref_dict["encoded_ref_feature"]
28
+ res = x @ y.transpose(-1, -2)
29
+ return res
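For reference (editorial sketch, not in the commit): the module is a plain dot-product similarity between queries and the encoded reference features.

import torch

assign = ContrastiveAssign()
queries = torch.randn(2, 900, 256)                              # image-side queries
ref_dict = {"encoded_ref_feature": torch.randn(2, 4, 256)}      # 4 encoded reference embeddings
logits = assign(queries, ref_dict)                              # [2, 900, 4] similarity logits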
detect_tools/upn/models/module/mlp.py ADDED
@@ -0,0 +1,18 @@
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+
4
+
5
+ class MLP(nn.Module):
6
+ """ Very simple multi-layer perceptron (also called FFN)"""
7
+
8
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
9
+ super().__init__()
10
+ self.num_layers = num_layers
11
+ h = [hidden_dim] * (num_layers - 1)
12
+ self.layers = nn.ModuleList(
13
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
14
+
15
+ def forward(self, x):
16
+ for i, layer in enumerate(self.layers):
17
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
18
+ return x
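A typical DETR-style use of this head is as a small box regressor (editorial example; the head sizes actually used by UPN are configured elsewhere in the repo). The last layer has no ReLU, so a sigmoid can map it to normalized coordinates.

import torch

box_head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
boxes = box_head(torch.randn(2, 900, 256)).sigmoid()   # [2, 900, 4] normalized (cx, cy, w, h)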
detect_tools/upn/models/module/nested_tensor.py ADDED
@@ -0,0 +1,199 @@
1
+ from typing import List, Union
2
+
3
+ import torch
4
+ import torchvision
5
+
6
+
7
+ class NestedTensor(object):
8
+ """Define a NestedTensor class
9
+
10
+ Args:
11
+ tensors (torch.Tensor): Tensor with shape [batch, C, H, W] or [C, H, W]
12
+ mask (Union[torch.Tensor, str]): mask with shape [batch, H, W] or [H, W]. If mask
13
+ is 'auto', it will be generated automatically by summing the tensor along
14
+ the channel dimension. Mask is used to indicate the padding area.
15
+ """
16
+
17
+ def __init__(
18
+ self, tensors: torch.Tensor, mask: Union[torch.Tensor, str] = "auto"
19
+ ) -> None:
20
+ self.tensors = tensors
21
+ self.mask = mask
22
+ if mask == "auto":
23
+ self.mask = torch.zeros_like(tensors).to(tensors.device)
24
+ if self.mask.dim() == 3:
25
+ self.mask = self.mask.sum(0).to(bool)
26
+ elif self.mask.dim() == 4:
27
+ self.mask = self.mask.sum(1).to(bool)
28
+ else:
29
+ raise ValueError(
30
+ "tensors dim must be 3 or 4 but {}({})".format(
31
+ self.tensors.dim(), self.tensors.shape
32
+ )
33
+ )
34
+
35
+ def imgsize(self) -> List[torch.Tensor]:
36
+ """get the img size of the tensor
37
+
38
+ Returns:
39
+ list[torch.Tensor]: list of tensor with shape [2] which is [H, W]
40
+ """
41
+ res = []
42
+ for i in range(self.tensors.shape[0]):
43
+ mask = self.mask[i]
44
+ maxH = (~mask).sum(0).max()
45
+ maxW = (~mask).sum(1).max()
46
+ res.append(torch.Tensor([maxH, maxW]))
47
+ return res
48
+
49
+ def to(self, device: torch.device):
50
+ """Move tensors and mask to the given device
51
+
52
+ Args:
53
+ device (torch.device): device to move
54
+
55
+ Returns:
56
+ NestedTensor: moved NestedTensor
57
+ """
58
+ cast_tensor = self.tensors.to(device)
59
+ mask = self.mask
60
+ if mask is not None:
61
+ assert mask is not None
62
+ cast_mask = mask.to(device)
63
+ else:
64
+ cast_mask = None
65
+ return NestedTensor(cast_tensor, cast_mask)
66
+
67
+ def to_img_list_single(
68
+ self, tensor: torch.Tensor, mask: torch.Tensor
69
+ ) -> torch.Tensor:
70
+ """remove the padding for one image
71
+
72
+ Args:
73
+ tensor (torch.Tensor): tensor with shape [C, H, W]
74
+ mask (torch.Tensor): mask with shape [H, W]
75
+
76
+ Returns:
77
+ torch.Tensor: tensor with shape [C, maxH, maxW]
78
+ """
79
+ assert tensor.dim() == 3, "dim of tensor should be 3 but {}".format(
80
+ tensor.dim()
81
+ )
82
+ maxH = (~mask).sum(0).max()
83
+ maxW = (~mask).sum(1).max()
84
+ img = tensor[:, :maxH, :maxW]
85
+ return img
86
+
87
+ def to_img_list(self) -> List[torch.Tensor]:
88
+ """remove the padding and convert to img list
89
+
90
+ Returns:
91
+ list[torch.Tensor]: list of tensor with shape [C, maxH, maxW]
92
+ """
93
+ if self.tensors.dim() == 3:
94
+ return self.to_img_list_single(self.tensors, self.mask)
95
+ else:
96
+ res = []
97
+ for i in range(self.tensors.shape[0]):
98
+ tensor_i = self.tensors[i]
99
+ mask_i = self.mask[i]
100
+ res.append(self.to_img_list_single(tensor_i, mask_i))
101
+ return res
102
+
103
+ @property
104
+ def device(self):
105
+ return self.tensors.device
106
+
107
+ def decompose(self):
108
+ return self.tensors, self.mask
109
+
110
+ def __repr__(self):
111
+ return str(self.tensors)
112
+
113
+ @property
114
+ def shape(self):
115
+ return {"tensors.shape": self.tensors.shape, "mask.shape": self.mask.shape}
116
+
117
+
118
+ def _max_by_axis(the_list):
119
+ # type: (List[List[int]]) -> List[int]
120
+ maxes = the_list[0]
121
+ for sublist in the_list[1:]:
122
+ for index, item in enumerate(sublist):
123
+ maxes[index] = max(maxes[index], item)
124
+ return maxes
125
+
126
+
127
+ def nested_tensor_from_tensor_list(
128
+ tensor_list: List[torch.Tensor], fixed_img_size=None
129
+ ):
130
+ if fixed_img_size is not None:
131
+ if isinstance(fixed_img_size, (list, tuple)):
132
+ assert (
133
+ len(fixed_img_size) == 2
134
+ ), "image size should be a tuple or list with two elements"
135
+ elif isinstance(fixed_img_size, int):
136
+ fixed_img_size = [fixed_img_size, fixed_img_size]
137
+
138
+ if tensor_list[0].ndim == 3:
139
+ if torchvision._is_tracing():
140
+ # nested_tensor_from_tensor_list() does not export well to ONNX
141
+ # call _onnx_nested_tensor_from_tensor_list() instead
142
+ return _onnx_nested_tensor_from_tensor_list(tensor_list)
143
+
144
+ # TODO make it support different-sized images
145
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
146
+
147
+ if fixed_img_size is not None:
148
+ c, orig_h, orig_w = max_size
149
+ assert (
150
+ orig_h <= fixed_img_size[0] and orig_w <= fixed_img_size[1]
151
+ ), f"{orig_h} {orig_w} the fixed output image size should be larger than original image"
152
+ max_size = [c, fixed_img_size[0], fixed_img_size[1]]
153
+
154
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
155
+ batch_shape = [len(tensor_list)] + max_size
156
+ b, c, h, w = batch_shape
157
+ dtype = tensor_list[0].dtype
158
+ device = tensor_list[0].device
159
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
160
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
161
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
162
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
163
+ m[: img.shape[1], : img.shape[2]] = False
164
+ else:
165
+ raise ValueError("not supported")
166
+ return NestedTensor(tensor, mask)
167
+
168
+
169
+ @torch.jit.unused
170
+ def _onnx_nested_tensor_from_tensor_list(
171
+ tensor_list: List[torch.Tensor],
172
+ ) -> NestedTensor:
173
+ max_size = []
174
+ for i in range(tensor_list[0].dim()):
175
+ max_size_i = torch.max(
176
+ torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
177
+ ).to(torch.int64)
178
+ max_size.append(max_size_i)
179
+ max_size = tuple(max_size)
180
+
181
+ padded_imgs = []
182
+ padded_masks = []
183
+ for img in tensor_list:
184
+ padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
185
+ padded_img = torch.nn.functional.pad(
186
+ img, (0, padding[2], 0, padding[1], 0, padding[0])
187
+ )
188
+ padded_imgs.append(padded_img)
189
+
190
+ m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
191
+ padded_mask = torch.nn.functional.pad(
192
+ m, (0, padding[2], 0, padding[1]), "constant", 1
193
+ )
194
+ padded_masks.append(padded_mask.to(torch.bool))
195
+
196
+ tensor = torch.stack(padded_imgs)
197
+ mask = torch.stack(padded_masks)
198
+
199
+ return NestedTensor(tensor, mask=mask)
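Editorial example (not in the commit) of batching two images of different sizes; the returned mask is True on padded pixels.

import torch

imgs = [torch.randn(3, 480, 640), torch.randn(3, 600, 500)]
batch = nested_tensor_from_tensor_list(imgs)            # pads both images to [3, 600, 640]
tensors, mask = batch.decompose()
print(tensors.shape, mask.shape)                        # [2, 3, 600, 640], [2, 600, 640]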
detect_tools/upn/models/utils/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ from .detr_utils import (
2
+ PositionEmbeddingLearned,
3
+ PositionEmbeddingSine,
4
+ PositionEmbeddingSineHW,
5
+ clean_state_dict,
6
+ gen_encoder_output_proposals,
7
+ gen_sineembed_for_position,
8
+ get_activation_fn,
9
+ get_clones,
10
+ inverse_sigmoid,
11
+ )
12
+
13
+ __all__ = [
14
+ "inverse_sigmoid",
15
+ "gen_encoder_output_proposals",
16
+ "get_clones",
17
+ "gen_sineembed_for_position",
18
+ "get_activation_fn",
19
+ "clean_state_dict",
20
+ "PositionEmbeddingSine",
21
+ "PositionEmbeddingSineHW",
22
+ "PositionEmbeddingLearned",
23
+ ]
detect_tools/upn/models/utils/detr_utils.py ADDED
@@ -0,0 +1,415 @@
1
+ import copy
2
+ import math
3
+ from collections import OrderedDict
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+ from detect_tools.upn import POS_EMBEDDINGS
11
+ from detect_tools.upn.models.module import NestedTensor
12
+
13
+
14
+ @POS_EMBEDDINGS.register_module()
15
+ class PositionEmbeddingSine(nn.Module):
16
+ """This is a more standard version of the position embedding, very similar to the one
17
+ used by the Attention is all you need paper, generalized to work on images.
18
+
19
+ Args:
20
+ num_pos_feats (int): The channel of positional embeddings.
21
+ temperature (float): The temperature used in positional embeddings.
22
+ normalize (bool): Whether to normalize the positional embeddings.
23
+ scale (float): The scale factor of positional embeddings.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ num_pos_feats: int = 64,
29
+ temperature: int = 10000,
30
+ normalize: bool = False,
31
+ scale: float = None,
32
+ ) -> None:
33
+ super().__init__()
34
+ self.num_pos_feats = num_pos_feats
35
+ self.temperature = temperature
36
+ self.normalize = normalize
37
+ if scale is not None and normalize is False:
38
+ raise ValueError("normalize should be True if scale is passed")
39
+ if scale is None:
40
+ scale = 2 * math.pi
41
+ self.scale = scale
42
+
43
+ def forward(self, tensor_list: NestedTensor) -> torch.Tensor:
44
+ """Forward function.
45
+
46
+ Args:
47
+ tensor_list (NestedTensor): NestedTensor wrapping the input tensor.
48
+
49
+ Returns:
50
+ torch.Tensor: Positional encoding in shape (B, num_pos_feats*2, H, W)
51
+ """
52
+ x = tensor_list.tensors
53
+ mask = tensor_list.mask
54
+ assert mask is not None
55
+ not_mask = ~mask
56
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
57
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
58
+ if self.normalize:
59
+ eps = 1e-6
60
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
61
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
62
+
63
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
64
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
65
+
66
+ pos_x = x_embed[:, :, :, None] / dim_t
67
+ pos_y = y_embed[:, :, :, None] / dim_t
68
+ pos_x = torch.stack(
69
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
70
+ ).flatten(3)
71
+ pos_y = torch.stack(
72
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
73
+ ).flatten(3)
74
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
75
+ return pos
76
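Editorial usage sketch: with num_pos_feats=128 the embedding has 256 channels, matching d_model; the mask here comes from the NestedTensor's "auto" mode, i.e. no padding.

import torch

feat = NestedTensor(torch.randn(2, 256, 25, 34))        # mask generated automatically (all valid)
pos = PositionEmbeddingSine(num_pos_feats=128, normalize=True)(feat)
print(pos.shape)                                        # torch.Size([2, 256, 25, 34])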
+
77
+
78
+ @POS_EMBEDDINGS.register_module()
79
+ class PositionEmbeddingSineHW(nn.Module):
80
+ """This is a more standard version of the position embedding, very similar to the one
81
+ used by the Attention is all you need paper, generalized to work on images.
82
+
83
+ Args:
84
+ num_pos_feats (int): The channel of positional embeddings.
85
+ temperatureH (float): The temperature used in positional embeddings.
86
+ temperatureW (float): The temperature used in positional embeddings.
87
+ normalize (bool): Whether to normalize the positional embeddings.
88
+ scale (float): The scale factor of positional embeddings.
89
+ """
90
+
91
+ def __init__(
92
+ self,
93
+ num_pos_feats: int = 64,
94
+ temperatureH: int = 10000,
95
+ temperatureW: int = 10000,
96
+ normalize: bool = False,
97
+ scale: float = None,
98
+ ) -> None:
99
+ super().__init__()
100
+ self.num_pos_feats = num_pos_feats
101
+ self.temperatureH = temperatureH
102
+ self.temperatureW = temperatureW
103
+ self.normalize = normalize
104
+ if scale is not None and normalize is False:
105
+ raise ValueError("normalize should be True if scale is passed")
106
+ if scale is None:
107
+ scale = 2 * math.pi
108
+ self.scale = scale
109
+
110
+ def forward(self, tensor_list: NestedTensor) -> torch.Tensor:
111
+ """Forward function.
112
+
113
+ Args:
114
+ tensor_list (NestedTensor): NestedTensor wrapping the input tensor.
115
+
116
+ Returns:
117
+ torch.Tensor: Positional encoding in shape (B, num_pos_feats*2, H, W)
118
+ """
119
+ x = tensor_list.tensors
120
+ mask = tensor_list.mask
121
+ assert mask is not None
122
+ not_mask = ~mask
123
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
124
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
125
+
126
+ if self.normalize:
127
+ eps = 1e-6
128
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
129
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
130
+
131
+ dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
132
+ dim_tx = self.temperatureW ** (2 * (dim_tx // 2) / self.num_pos_feats)
133
+ pos_x = x_embed[:, :, :, None] / dim_tx
134
+
135
+ dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
136
+ dim_ty = self.temperatureH ** (2 * (dim_ty // 2) / self.num_pos_feats)
137
+ pos_y = y_embed[:, :, :, None] / dim_ty
138
+
139
+ pos_x = torch.stack(
140
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
141
+ ).flatten(3)
142
+ pos_y = torch.stack(
143
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
144
+ ).flatten(3)
145
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
146
+
147
+ return pos
148
+
149
+
150
+ @POS_EMBEDDINGS.register_module()
151
+ class PositionEmbeddingLearned(nn.Module):
152
+ """Absolute pos embedding, learned.
153
+
154
+ Args:
155
+ num_pos_feats (int): The channel dimension of positional embeddings.
156
+ num_row (int): The number of rows of the input feature map.
157
+ num_col (int): The number of columns of the input feature map.
158
+ """
159
+
160
+ def __init__(
161
+ self, num_row: int = 50, num_col: int = 50, num_pos_feats: int = 256
162
+ ) -> None:
163
+ super().__init__()
164
+ self.row_embed = nn.Embedding(num_row, num_pos_feats)
165
+ self.col_embed = nn.Embedding(num_col, num_pos_feats)
166
+ self.reset_parameters()
167
+
168
+ def reset_parameters(self):
169
+ nn.init.uniform_(self.row_embed.weight)
170
+ nn.init.uniform_(self.col_embed.weight)
171
+
172
+ def forward(self, tensor_list: NestedTensor) -> torch.Tensor:
173
+ """Forward function.
174
+
175
+ Args:
176
+ tensor_list (NestedTensor): NestedTensor wrapping the input tensor.
177
+
178
+ Returns:
179
+ torch.Tensor: Positional encoding in shape (B, num_pos_feats*2, H, W)
180
+ """
181
+ x = tensor_list.tensors
182
+ h, w = x.shape[-2:]
183
+ i = torch.arange(w, device=x.device)
184
+ j = torch.arange(h, device=x.device)
185
+ x_emb = self.col_embed(i)
186
+ y_emb = self.row_embed(j)
187
+ pos = (
188
+ torch.cat(
189
+ [
190
+ x_emb.unsqueeze(0).repeat(h, 1, 1),
191
+ y_emb.unsqueeze(1).repeat(1, w, 1),
192
+ ],
193
+ dim=-1,
194
+ )
195
+ .permute(2, 0, 1)
196
+ .unsqueeze(0)
197
+ .repeat(x.shape[0], 1, 1, 1)
198
+ )
199
+ return pos
200
+
201
+
202
+ def build_position_encoding(args):
203
+ N_steps = args.hidden_dim // 2
204
+ if args.position_embedding in ("v2", "sine"):
205
+ # TODO find a better way of exposing other arguments
206
+ position_embedding = PositionEmbeddingSineHW(
207
+ N_steps,
208
+ temperatureH=args.pe_temperatureH,
209
+ temperatureW=args.pe_temperatureW,
210
+ normalize=True,
211
+ )
212
+ elif args.position_embedding in ("v3", "learned"):
213
+ position_embedding = PositionEmbeddingLearned(N_steps)
214
+ else:
215
+ raise ValueError(f"not supported {args.position_embedding}")
216
+
217
+ return position_embedding
218
+
219
+
220
+ def clean_state_dict(state_dict):
221
+ new_state_dict = OrderedDict()
222
+ for k, v in state_dict.items():
223
+ if k[:7] == "module.":
224
+ k = k[7:] # remove `module.`
225
+ new_state_dict[k] = v
226
+ return new_state_dict
227
+
228
+
229
+ def get_activation_fn(activation: str, d_model: int = 256, batch_dim: int = 0):
230
+ """Return an activation function given a string
231
+
232
+ Args:
233
+ activation (str): activation function name
234
+ d_model (int, optional): d_model. Defaults to 256.
235
+ batch_dim (int, optional): batch dimension. Defaults to 0.
236
+
237
+ Returns:
238
+ F: activation function
239
+ """
240
+ if activation == "relu":
241
+ return F.relu
242
+ if activation == "gelu":
243
+ return F.gelu
244
+ if activation == "glu":
245
+ return F.glu
246
+ if activation == "prelu":
247
+ return nn.PReLU()
248
+ if activation == "selu":
249
+ return F.selu
250
+
251
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
252
+
253
+
254
+ def get_clones(module: nn.Module, N: int, layer_share: bool = False):
255
+ """Copy module N times
256
+
257
+ Args:
258
+ module (nn.Module): module to copy
259
+ N (int): number of copies
260
+ layer_share (bool, optional): share the same layer. If true, the modules will
261
+ share the same memory. Defaults to False.
262
+ """
263
+ if layer_share:
264
+ return nn.ModuleList([module for _ in range(N)])
265
+ else:
266
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
267
+
268
+
269
+ def inverse_sigmoid(x, eps=1e-3):
270
+ x = x.clamp(min=0, max=1)
271
+ x1 = x.clamp(min=eps)
272
+ x2 = (1 - x).clamp(min=eps)
273
+ return torch.log(x1 / x2)
274
+
275
+
276
+ def gen_sineembed_for_position(pos_tensor):
277
+ # n_query, bs, _ = pos_tensor.size()
278
+ # sineembed_tensor = torch.zeros(n_query, bs, 256)
279
+ scale = 2 * math.pi
280
+ dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
281
+ dim_t = 10000 ** (2 * (dim_t // 2) / 128)
282
+ x_embed = pos_tensor[:, :, 0] * scale
283
+ y_embed = pos_tensor[:, :, 1] * scale
284
+ pos_x = x_embed[:, :, None] / dim_t
285
+ pos_y = y_embed[:, :, None] / dim_t
286
+ pos_x = torch.stack(
287
+ (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3
288
+ ).flatten(2)
289
+ pos_y = torch.stack(
290
+ (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3
291
+ ).flatten(2)
292
+ if pos_tensor.size(-1) == 2:
293
+ pos = torch.cat((pos_y, pos_x), dim=2)
294
+ elif pos_tensor.size(-1) == 4:
295
+ w_embed = pos_tensor[:, :, 2] * scale
296
+ pos_w = w_embed[:, :, None] / dim_t
297
+ pos_w = torch.stack(
298
+ (pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3
299
+ ).flatten(2)
300
+
301
+ h_embed = pos_tensor[:, :, 3] * scale
302
+ pos_h = h_embed[:, :, None] / dim_t
303
+ pos_h = torch.stack(
304
+ (pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3
305
+ ).flatten(2)
306
+
307
+ pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
308
+ else:
309
+ raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
310
+ return pos
311
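Editorial shape check: the function expects [n_query, bs, 2 or 4] normalized positions and returns 128 sine/cosine features per coordinate.

import torch

boxes = torch.rand(900, 2, 4)                       # [n_query, bs, 4] normalized (cx, cy, w, h)
query_sine_embed = gen_sineembed_for_position(boxes)
print(query_sine_embed.shape)                       # torch.Size([900, 2, 512])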
+
312
+
313
+ def get_sine_pos_embed(
314
+ pos_tensor: torch.Tensor,
315
+ num_pos_feats: int = 128,
316
+ temperature: int = 10000,
317
+ exchange_xy: bool = True,
318
+ ):
319
+ """generate sine position embedding from a position tensor
320
+ Args:
321
+ pos_tensor (torch.Tensor): shape: [..., n].
322
+ num_pos_feats (int): projected shape for each float in the tensor.
323
+ temperature (int): temperature in the sine/cosine function.
324
+ exchange_xy (bool, optional): exchange pos x and pos y. \
325
+ For example, input tensor is [x,y], the results will be [pos(y), pos(x)]. Defaults to True.
326
+ Returns:
327
+ pos_embed (torch.Tensor): shape: [..., n*num_pos_feats].
328
+ """
329
+ scale = 2 * math.pi
330
+ dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device)
331
+ dim_t = temperature ** (
332
+ 2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats
333
+ )
334
+
335
+ def sine_func(x: torch.Tensor):
336
+ sin_x = x * scale / dim_t
337
+ sin_x = torch.stack(
338
+ (sin_x[..., 0::2].sin(), sin_x[..., 1::2].cos()), dim=3
339
+ ).flatten(2)
340
+ return sin_x
341
+
342
+ pos_res = [
343
+ sine_func(x) for x in pos_tensor.split([1] * pos_tensor.shape[-1], dim=-1)
344
+ ]
345
+ if exchange_xy:
346
+ pos_res[0], pos_res[1] = pos_res[1], pos_res[0]
347
+ pos_res = torch.cat(pos_res, dim=-1)
348
+ return pos_res
349
+
350
+
351
+ def gen_encoder_output_proposals(
352
+ memory: torch.Tensor,
353
+ memory_padding_mask: torch.Tensor,
354
+ spatial_shapes: torch.Tensor,
355
+ learnedwh=None,
356
+ ):
357
+ """
358
+ Input:
359
+ - memory: bs, \sum{hw}, d_model
360
+ - memory_padding_mask: bs, \sum{hw}
361
+ - spatial_shapes: nlevel, 2
362
+ - learnedwh: 2
363
+ Output:
364
+ - output_memory: bs, \sum{hw}, d_model
365
+ - output_proposals: bs, \sum{hw}, 4
366
+ """
367
+ N_, S_, C_ = memory.shape
368
+ base_scale = 4.0
369
+ proposals = []
370
+ _cur = 0
371
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
372
+ mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view(
373
+ N_, H_, W_, 1
374
+ )
375
+ valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
376
+ valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
377
+ grid_y, grid_x = torch.meshgrid(
378
+ torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
379
+ torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
380
+ )
381
+ grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2
382
+
383
+ scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(
384
+ N_, 1, 1, 2
385
+ )
386
+ grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
387
+
388
+ if learnedwh is not None:
389
+ wh = torch.ones_like(grid) * learnedwh.sigmoid() * (2.0**lvl)
390
+ else:
391
+ wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
392
+
393
+ proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
394
+ proposals.append(proposal)
395
+ _cur += H_ * W_
396
+
397
+ output_proposals = torch.cat(proposals, 1)
398
+ output_proposals_valid = (
399
+ (output_proposals > 0.01) & (output_proposals < 0.99)
400
+ ).all(-1, keepdim=True)
401
+ output_proposals = torch.log(output_proposals / (1 - output_proposals)) # unsigmoid
402
+ output_proposals = output_proposals.masked_fill(
403
+ memory_padding_mask.unsqueeze(-1), float("inf")
404
+ )
405
+ output_proposals = output_proposals.masked_fill(
406
+ ~output_proposals_valid, float("inf")
407
+ )
408
+
409
+ output_memory = memory
410
+ output_memory = output_memory.masked_fill(
411
+ memory_padding_mask.unsqueeze(-1), float(0)
412
+ )
413
+ output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
414
+
415
+ return output_memory, output_proposals
detect_tools/upn/ops/functions/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ from .ms_deform_attn_func import MSDeformAttnFunction
10
+
detect_tools/upn/ops/functions/ms_deform_attn_func.py ADDED
@@ -0,0 +1,61 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ from __future__ import absolute_import
10
+ from __future__ import print_function
11
+ from __future__ import division
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from torch.autograd import Function
16
+ from torch.autograd.function import once_differentiable
17
+
18
+ import MultiScaleDeformableAttention as MSDA
19
+
20
+
21
+ class MSDeformAttnFunction(Function):
22
+ @staticmethod
23
+ def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
24
+ ctx.im2col_step = im2col_step
25
+ output = MSDA.ms_deform_attn_forward(
26
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
27
+ ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
28
+ return output
29
+
30
+ @staticmethod
31
+ @once_differentiable
32
+ def backward(ctx, grad_output):
33
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
34
+ grad_value, grad_sampling_loc, grad_attn_weight = \
35
+ MSDA.ms_deform_attn_backward(
36
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
37
+
38
+ return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
39
+
40
+
41
+ def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
42
+ # for debug and test only,
43
+ # need to use cuda version instead
44
+ N_, S_, M_, D_ = value.shape
45
+ _, Lq_, M_, L_, P_, _ = sampling_locations.shape
46
+ value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
47
+ sampling_grids = 2 * sampling_locations - 1
48
+ sampling_value_list = []
49
+ for lid_, (H_, W_) in enumerate(value_spatial_shapes):
50
+ # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
51
+ value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
52
+ # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
53
+ sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
54
+ # N_*M_, D_, Lq_, P_
55
+ sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
56
+ mode='bilinear', padding_mode='zeros', align_corners=False)
57
+ sampling_value_list.append(sampling_value_l_)
58
+ # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
59
+ attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
60
+ output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
61
+ return output.transpose(1, 2).contiguous()
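Editorial sketch: the pure-PyTorch reference above can be exercised on random inputs (shapes mirror ops/test.py), which is handy when the compiled CUDA extension is not available.

import torch

N, M, D = 1, 2, 8                                     # batch, heads, channels per head
Lq, L, P = 5, 2, 4                                    # queries, levels, points per level
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
S = int(shapes.prod(1).sum())
value = torch.rand(N, S, M, D)
sampling_locations = torch.rand(N, Lq, M, L, P, 2)
attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
out = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights)
print(out.shape)                                      # torch.Size([1, 5, 16]) == [N, Lq, M*D]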
detect_tools/upn/ops/modules/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ from .ms_deform_attn import MSDeformAttn
detect_tools/upn/ops/modules/ms_deform_attn.py ADDED
@@ -0,0 +1,204 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ from __future__ import absolute_import
10
+ from __future__ import print_function
11
+ from __future__ import division
12
+
13
+ import warnings
14
+ import math, os
15
+
16
+ import torch
17
+ from torch import nn
18
+ import torch.nn.functional as F
19
+ from torch.nn.init import xavier_uniform_, constant_
20
+
21
+ try:
22
+ from ..functions import MSDeformAttnFunction
23
+ except:
24
+ warnings.warn("Failed to import MSDeformAttnFunction.")
25
+
26
+
27
+ def _is_power_of_2(n):
28
+ if (not isinstance(n, int)) or (n < 0):
29
+ raise ValueError(
30
+ "invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))
31
+ )
32
+ return (n & (n - 1) == 0) and n != 0
33
+
34
+
35
+ class MSDeformAttn(nn.Module):
36
+ def __init__(
37
+ self, d_model=256, n_levels=4, n_heads=8, n_points=4, use_4D_normalizer=False
38
+ ):
39
+ """
40
+ Multi-Scale Deformable Attention Module
41
+ :param d_model hidden dimension
42
+ :param n_levels number of feature levels
43
+ :param n_heads number of attention heads
44
+ :param n_points number of sampling points per attention head per feature level
45
+ """
46
+ super().__init__()
47
+ if d_model % n_heads != 0:
48
+ raise ValueError(
49
+ "d_model must be divisible by n_heads, but got {} and {}".format(
50
+ d_model, n_heads
51
+ )
52
+ )
53
+ _d_per_head = d_model // n_heads
54
+ # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
55
+ if not _is_power_of_2(_d_per_head):
56
+ warnings.warn(
57
+ "You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
58
+ "which is more efficient in our CUDA implementation."
59
+ )
60
+
61
+ self.im2col_step = 64
62
+
63
+ self.d_model = d_model
64
+ self.n_levels = n_levels
65
+ self.n_heads = n_heads
66
+ self.n_points = n_points
67
+
68
+ self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
69
+ self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
70
+ self.value_proj = nn.Linear(d_model, d_model)
71
+ self.output_proj = nn.Linear(d_model, d_model)
72
+
73
+ self.use_4D_normalizer = use_4D_normalizer
74
+
75
+ self._reset_parameters()
76
+
77
+ def _reset_parameters(self):
78
+ constant_(self.sampling_offsets.weight.data, 0.0)
79
+ thetas = torch.arange(self.n_heads, dtype=torch.float32) * (
80
+ 2.0 * math.pi / self.n_heads
81
+ )
82
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
83
+ grid_init = (
84
+ (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
85
+ .view(self.n_heads, 1, 1, 2)
86
+ .repeat(1, self.n_levels, self.n_points, 1)
87
+ )
88
+ for i in range(self.n_points):
89
+ grid_init[:, :, i, :] *= i + 1
90
+ with torch.no_grad():
91
+ self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
92
+ constant_(self.attention_weights.weight.data, 0.0)
93
+ constant_(self.attention_weights.bias.data, 0.0)
94
+ xavier_uniform_(self.value_proj.weight.data)
95
+ constant_(self.value_proj.bias.data, 0.0)
96
+ xavier_uniform_(self.output_proj.weight.data)
97
+ constant_(self.output_proj.bias.data, 0.0)
98
+
99
+ @torch.cuda.amp.autocast(enabled=False)
100
+ def forward(
101
+ self,
102
+ query,
103
+ reference_points,
104
+ input_flatten,
105
+ input_spatial_shapes,
106
+ input_level_start_index,
107
+ input_padding_mask=None,
108
+ ):
109
+ """
110
+ :param query (N, Length_{query}, C)
111
+ :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
112
+ or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
113
+ :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
114
+ :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
115
+ :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
116
+ :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
117
+
118
+ :return output (N, Length_{query}, C)
119
+ """
120
+ N, Len_q, _ = query.shape
121
+ N, Len_in, _ = input_flatten.shape
122
+ assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
123
+
124
+ value = self.value_proj(input_flatten)
125
+ if input_padding_mask is not None:
126
+ value = value.masked_fill(input_padding_mask[..., None], float(0))
127
+ value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
128
+ sampling_offsets = self.sampling_offsets(query).view(
129
+ N, Len_q, self.n_heads, self.n_levels, self.n_points, 2
130
+ )
131
+ attention_weights = self.attention_weights(query).view(
132
+ N, Len_q, self.n_heads, self.n_levels * self.n_points
133
+ )
134
+ attention_weights = F.softmax(attention_weights, -1).view(
135
+ N, Len_q, self.n_heads, self.n_levels, self.n_points
136
+ )
137
+ # N, Len_q, n_heads, n_levels, n_points, 2
138
+
139
+ # if os.environ.get('IPDB_DEBUG_SHILONG', False) == 'INFO':
140
+ # import ipdb; ipdb.set_trace()
141
+
142
+ if reference_points.shape[-1] == 2:
143
+ offset_normalizer = torch.stack(
144
+ [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1
145
+ )
146
+ sampling_locations = (
147
+ reference_points[:, :, None, :, None, :]
148
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
149
+ )
150
+ elif reference_points.shape[-1] == 4:
151
+ if self.use_4D_normalizer:
152
+ offset_normalizer = torch.stack(
153
+ [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1
154
+ )
155
+ sampling_locations = (
156
+ reference_points[:, :, None, :, None, :2]
157
+ + sampling_offsets
158
+ / offset_normalizer[None, None, None, :, None, :]
159
+ * reference_points[:, :, None, :, None, 2:]
160
+ * 0.5
161
+ )
162
+ else:
163
+ sampling_locations = (
164
+ reference_points[:, :, None, :, None, :2]
165
+ + sampling_offsets
166
+ / self.n_points
167
+ * reference_points[:, :, None, :, None, 2:]
168
+ * 0.5
169
+ )
170
+ else:
171
+ raise ValueError(
172
+ "Last dim of reference_points must be 2 or 4, but get {} instead.".format(
173
+ reference_points.shape[-1]
174
+ )
175
+ )
176
+
177
+ # if os.environ.get('IPDB_DEBUG_SHILONG', False) == 'INFO':
178
+ # import ipdb; ipdb.set_trace()
179
+
180
+ # for amp
181
+ if value.dtype == torch.float16:
182
+ # for mixed precision
183
+ output = MSDeformAttnFunction.apply(
184
+ value.to(torch.float32),
185
+ input_spatial_shapes,
186
+ input_level_start_index,
187
+ sampling_locations.to(torch.float32),
188
+ attention_weights,
189
+ self.im2col_step,
190
+ )
191
+ output = output.to(torch.float16)
192
+ output = self.output_proj(output)
193
+ return output
194
+
195
+ output = MSDeformAttnFunction.apply(
196
+ value,
197
+ input_spatial_shapes,
198
+ input_level_start_index,
199
+ sampling_locations,
200
+ attention_weights,
201
+ self.im2col_step,
202
+ )
203
+ output = self.output_proj(output)
204
+ return output
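Editorial sketch of the expected input layout for a single forward pass; this needs a GPU and the compiled MultiScaleDeformableAttention extension built from ops/setup.py.

import torch

attn = MSDeformAttn(d_model=256, n_levels=2, n_heads=8, n_points=4).cuda()
shapes = torch.as_tensor([[28, 38], [14, 19]], device="cuda")
start_index = torch.cat((shapes.new_zeros(1), shapes.prod(1).cumsum(0)[:-1]))
src = torch.rand(1, int(shapes.prod(1).sum()), 256, device="cuda")   # flattened multi-level features
ref_pts = torch.rand(1, src.shape[1], 1, 2, device="cuda").repeat(1, 1, 2, 1)
out = attn(query=src, reference_points=ref_pts, input_flatten=src,
           input_spatial_shapes=shapes, input_level_start_index=start_index)
print(out.shape)                                                     # torch.Size([1, 1330, 256])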
detect_tools/upn/ops/modules/ms_deform_attn_key_aware.py ADDED
@@ -0,0 +1,130 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ from __future__ import absolute_import
10
+ from __future__ import print_function
11
+ from __future__ import division
12
+
13
+ import warnings
14
+ import math, os
15
+
16
+ import torch
17
+ from torch import nn
18
+ import torch.nn.functional as F
19
+ from torch.nn.init import xavier_uniform_, constant_
20
+
21
+ try:
22
+ from ..functions import MSDeformAttnFunction
23
+ except ImportError:
24
+ warnings.warn('Failed to import MSDeformAttnFunction.')
25
+
26
+
27
+ def _is_power_of_2(n):
28
+ if (not isinstance(n, int)) or (n < 0):
29
+ raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
30
+ return (n & (n-1) == 0) and n != 0
31
+
32
+
33
+ class MSDeformAttn(nn.Module):
34
+ def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, use_4D_normalizer=False):
35
+ """
36
+ Multi-Scale Deformable Attention Module
37
+ :param d_model hidden dimension
38
+ :param n_levels number of feature levels
39
+ :param n_heads number of attention heads
40
+ :param n_points number of sampling points per attention head per feature level
41
+ """
42
+ super().__init__()
43
+ if d_model % n_heads != 0:
44
+ raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
45
+ _d_per_head = d_model // n_heads
46
+ # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
47
+ if not _is_power_of_2(_d_per_head):
48
+ warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
49
+ "which is more efficient in our CUDA implementation.")
50
+
51
+ self.im2col_step = 64
52
+
53
+ self.d_model = d_model
54
+ self.n_levels = n_levels
55
+ self.n_heads = n_heads
56
+ self.n_points = n_points
57
+
58
+ self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
59
+ self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
60
+ self.value_proj = nn.Linear(d_model, d_model)
61
+ self.output_proj = nn.Linear(d_model, d_model)
62
+
63
+ self.use_4D_normalizer = use_4D_normalizer
64
+
65
+ self._reset_parameters()
66
+
67
+ def _reset_parameters(self):
68
+ constant_(self.sampling_offsets.weight.data, 0.)
69
+ thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
70
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
71
+ grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
72
+ for i in range(self.n_points):
73
+ grid_init[:, :, i, :] *= i + 1
74
+ with torch.no_grad():
75
+ self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
76
+ constant_(self.attention_weights.weight.data, 0.)
77
+ constant_(self.attention_weights.bias.data, 0.)
78
+ xavier_uniform_(self.value_proj.weight.data)
79
+ constant_(self.value_proj.bias.data, 0.)
80
+ xavier_uniform_(self.output_proj.weight.data)
81
+ constant_(self.output_proj.bias.data, 0.)
82
+
83
+ def forward(self, query, key, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
84
+ """
85
+ :param query (N, Length_{query}, C)
86
+ :param key (N, 1, C)
87
+ :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
88
+ or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
89
+ :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
90
+ :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
91
+ :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
92
+ :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
93
+
94
+ :return output (N, Length_{query}, C)
95
+ """
96
+ N, Len_q, _ = query.shape
97
+ N, Len_in, _ = input_flatten.shape
98
+ assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
99
+
100
+ value = self.value_proj(input_flatten)
101
+ if input_padding_mask is not None:
102
+ value = value.masked_fill(input_padding_mask[..., None], float(0))
103
+ value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
104
+ sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
105
+ attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
106
+ attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
107
+ # N, Len_q, n_heads, n_levels, n_points, 2
108
+
109
+ # if os.environ.get('IPDB_DEBUG_SHILONG', False) == 'INFO':
110
+ # import ipdb; ipdb.set_trace()
111
+
112
+ if reference_points.shape[-1] == 2:
113
+ offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
114
+ sampling_locations = reference_points[:, :, None, :, None, :] \
115
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
116
+ elif reference_points.shape[-1] == 4:
117
+ if self.use_4D_normalizer:
118
+ offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
119
+ sampling_locations = reference_points[:, :, None, :, None, :2] \
120
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :] * reference_points[:, :, None, :, None, 2:] * 0.5
121
+ else:
122
+ sampling_locations = reference_points[:, :, None, :, None, :2] \
123
+ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
124
+ else:
125
+ raise ValueError(
126
+ 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
127
+ output = MSDeformAttnFunction.apply(
128
+ value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
129
+ output = self.output_proj(output)
130
+ return output
detect_tools/upn/ops/setup.py ADDED
@@ -0,0 +1,73 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ import os
10
+ import glob
11
+
12
+ import torch
13
+
14
+ from torch.utils.cpp_extension import CUDA_HOME
15
+ from torch.utils.cpp_extension import CppExtension
16
+ from torch.utils.cpp_extension import CUDAExtension
17
+
18
+ from setuptools import find_packages
19
+ from setuptools import setup
20
+
21
+ requirements = ["torch", "torchvision"]
22
+
23
+ def get_extensions():
24
+ this_dir = os.path.dirname(os.path.abspath(__file__))
25
+ extensions_dir = os.path.join(this_dir, "src")
26
+
27
+ main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
28
+ source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
29
+ source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
30
+
31
+ sources = main_file + source_cpu
32
+ extension = CppExtension
33
+ extra_compile_args = {"cxx": []}
34
+ define_macros = []
35
+
36
+ # import ipdb; ipdb.set_trace()
37
+
38
+ if torch.cuda.is_available() and CUDA_HOME is not None:
39
+ extension = CUDAExtension
40
+ sources += source_cuda
41
+ define_macros += [("WITH_CUDA", None)]
42
+ extra_compile_args["nvcc"] = [
43
+ "-DCUDA_HAS_FP16=1",
44
+ "-D__CUDA_NO_HALF_OPERATORS__",
45
+ "-D__CUDA_NO_HALF_CONVERSIONS__",
46
+ "-D__CUDA_NO_HALF2_OPERATORS__",
47
+ ]
48
+ else:
49
+ raise NotImplementedError('CUDA is not available')
50
+
51
+ sources = [os.path.join(extensions_dir, s) for s in sources]
52
+ include_dirs = [extensions_dir]
53
+ ext_modules = [
54
+ extension(
55
+ "MultiScaleDeformableAttention",
56
+ sources,
57
+ include_dirs=include_dirs,
58
+ define_macros=define_macros,
59
+ extra_compile_args=extra_compile_args,
60
+ )
61
+ ]
62
+ return ext_modules
63
+
64
+ setup(
65
+ name="MultiScaleDeformableAttention",
66
+ version="1.0",
67
+ author="Weijie Su",
68
+ url="https://github.com/fundamentalvision/Deformable-DETR",
69
+ description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
70
+ packages=find_packages(exclude=("configs", "tests",)),
71
+ ext_modules=get_extensions(),
72
+ cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
73
+ )
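Editorial note (not part of the commit): the build is CUDA-only, typically run as `cd detect_tools/upn/ops && python setup.py build install`, followed by `python test.py`. A quick import check from Python:

try:
    import MultiScaleDeformableAttention as MSDA      # produced by the setup script above
    print("extension available:", hasattr(MSDA, "ms_deform_attn_forward"))
except ImportError:
    print("CUDA extension not built; MSDeformAttn.forward will fail at runtime")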
detect_tools/upn/ops/src/cpu/ms_deform_attn_cpu.cpp ADDED
@@ -0,0 +1,41 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #include <vector>
12
+
13
+ #include <ATen/ATen.h>
14
+ #include <ATen/cuda/CUDAContext.h>
15
+
16
+
17
+ at::Tensor
18
+ ms_deform_attn_cpu_forward(
19
+ const at::Tensor &value,
20
+ const at::Tensor &spatial_shapes,
21
+ const at::Tensor &level_start_index,
22
+ const at::Tensor &sampling_loc,
23
+ const at::Tensor &attn_weight,
24
+ const int im2col_step)
25
+ {
26
+ AT_ERROR("Not implemented on CPU");
27
+ }
28
+
29
+ std::vector<at::Tensor>
30
+ ms_deform_attn_cpu_backward(
31
+ const at::Tensor &value,
32
+ const at::Tensor &spatial_shapes,
33
+ const at::Tensor &level_start_index,
34
+ const at::Tensor &sampling_loc,
35
+ const at::Tensor &attn_weight,
36
+ const at::Tensor &grad_output,
37
+ const int im2col_step)
38
+ {
39
+ AT_ERROR("Not implemented on CPU");
40
+ }
41
+
detect_tools/upn/ops/src/cpu/ms_deform_attn_cpu.h ADDED
@@ -0,0 +1,33 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+ #include <torch/extension.h>
13
+
14
+ at::Tensor
15
+ ms_deform_attn_cpu_forward(
16
+ const at::Tensor &value,
17
+ const at::Tensor &spatial_shapes,
18
+ const at::Tensor &level_start_index,
19
+ const at::Tensor &sampling_loc,
20
+ const at::Tensor &attn_weight,
21
+ const int im2col_step);
22
+
23
+ std::vector<at::Tensor>
24
+ ms_deform_attn_cpu_backward(
25
+ const at::Tensor &value,
26
+ const at::Tensor &spatial_shapes,
27
+ const at::Tensor &level_start_index,
28
+ const at::Tensor &sampling_loc,
29
+ const at::Tensor &attn_weight,
30
+ const at::Tensor &grad_output,
31
+ const int im2col_step);
32
+
33
+
detect_tools/upn/ops/src/cuda/ms_deform_attn_cuda.cu ADDED
@@ -0,0 +1,153 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #include <vector>
12
+ #include "cuda/ms_deform_im2col_cuda.cuh"
13
+
14
+ #include <ATen/ATen.h>
15
+ #include <ATen/cuda/CUDAContext.h>
16
+ #include <cuda.h>
17
+ #include <cuda_runtime.h>
18
+
19
+
20
+ at::Tensor ms_deform_attn_cuda_forward(
21
+ const at::Tensor &value,
22
+ const at::Tensor &spatial_shapes,
23
+ const at::Tensor &level_start_index,
24
+ const at::Tensor &sampling_loc,
25
+ const at::Tensor &attn_weight,
26
+ const int im2col_step)
27
+ {
28
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
29
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
30
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
31
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
32
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
33
+
34
+ AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
35
+ AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
36
+ AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
37
+ AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
38
+ AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
39
+
40
+ const int batch = value.size(0);
41
+ const int spatial_size = value.size(1);
42
+ const int num_heads = value.size(2);
43
+ const int channels = value.size(3);
44
+
45
+ const int num_levels = spatial_shapes.size(0);
46
+
47
+ const int num_query = sampling_loc.size(1);
48
+ const int num_point = sampling_loc.size(4);
49
+
50
+ const int im2col_step_ = std::min(batch, im2col_step);
51
+
52
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
53
+
54
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
55
+
56
+ const int batch_n = im2col_step_;
57
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
58
+ auto per_value_size = spatial_size * num_heads * channels;
59
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
60
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
61
+ for (int n = 0; n < batch/im2col_step_; ++n)
62
+ {
63
+ auto columns = output_n.select(0, n);
64
+ AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
65
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
66
+ value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
67
+ spatial_shapes.data_ptr<int64_t>(),
68
+ level_start_index.data_ptr<int64_t>(),
69
+ sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
70
+ attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
71
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
72
+ columns.data_ptr<scalar_t>());
73
+
74
+ }));
75
+ }
76
+
77
+ output = output.view({batch, num_query, num_heads*channels});
78
+
79
+ return output;
80
+ }
81
+
82
+
83
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
84
+ const at::Tensor &value,
85
+ const at::Tensor &spatial_shapes,
86
+ const at::Tensor &level_start_index,
87
+ const at::Tensor &sampling_loc,
88
+ const at::Tensor &attn_weight,
89
+ const at::Tensor &grad_output,
90
+ const int im2col_step)
91
+ {
92
+
93
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
94
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
95
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
96
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
97
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
98
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
99
+
100
+ AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
101
+ AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
102
+ AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
103
+ AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
104
+ AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
105
+ AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor");
106
+
107
+ const int batch = value.size(0);
108
+ const int spatial_size = value.size(1);
109
+ const int num_heads = value.size(2);
110
+ const int channels = value.size(3);
111
+
112
+ const int num_levels = spatial_shapes.size(0);
113
+
114
+ const int num_query = sampling_loc.size(1);
115
+ const int num_point = sampling_loc.size(4);
116
+
117
+ const int im2col_step_ = std::min(batch, im2col_step);
118
+
119
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must be divisible by im2col_step(%d)", batch, im2col_step_);
120
+
121
+ auto grad_value = at::zeros_like(value);
122
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
123
+ auto grad_attn_weight = at::zeros_like(attn_weight);
124
+
125
+ const int batch_n = im2col_step_;
126
+ auto per_value_size = spatial_size * num_heads * channels;
127
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
128
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
129
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
130
+
131
+ for (int n = 0; n < batch/im2col_step_; ++n)
132
+ {
133
+ auto grad_output_g = grad_output_n.select(0, n);
134
+ AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
135
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
136
+ grad_output_g.data_ptr<scalar_t>(),
137
+ value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
138
+ spatial_shapes.data_ptr<int64_t>(),
139
+ level_start_index.data_ptr<int64_t>(),
140
+ sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
141
+ attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
142
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
143
+ grad_value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
144
+ grad_sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
145
+ grad_attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
146
+
147
+ }));
148
+ }
149
+
150
+ return {
151
+ grad_value, grad_sampling_loc, grad_attn_weight
152
+ };
153
+ }
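
Note: these forward/backward entry points are consumed on the Python side through a torch.autograd.Function wrapper (the repository's wrapper lives in detect_tools/upn/ops/functions/ms_deform_attn_func.py, not shown in this hunk). A minimal sketch, assuming the compiled extension is importable as MultiScaleDeformableAttention and exposes exactly the signatures declared above; this is not the repository's exact wrapper:

# Hedged sketch. The module name "MultiScaleDeformableAttention" is an
# assumption; setup.py in detect_tools/upn/ops defines the real name.
import torch
import MultiScaleDeformableAttention as MSDA


class MSDeformAttnFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, value, spatial_shapes, level_start_index,
                sampling_loc, attn_weight, im2col_step):
        ctx.im2col_step = im2col_step
        output = MSDA.ms_deform_attn_forward(
            value, spatial_shapes, level_start_index,
            sampling_loc, attn_weight, im2col_step)
        ctx.save_for_backward(value, spatial_shapes, level_start_index,
                              sampling_loc, attn_weight)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (value, spatial_shapes, level_start_index,
         sampling_loc, attn_weight) = ctx.saved_tensors
        grad_value, grad_sampling_loc, grad_attn_weight = \
            MSDA.ms_deform_attn_backward(
                value, spatial_shapes, level_start_index, sampling_loc,
                attn_weight, grad_output.contiguous(), ctx.im2col_step)
        # No gradients for the index tensors or the integer im2col_step.
        return (grad_value, None, None,
                grad_sampling_loc, grad_attn_weight, None)
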
detect_tools/upn/ops/src/cuda/ms_deform_attn_cuda.h ADDED
@@ -0,0 +1,30 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+ #include <torch/extension.h>
13
+
14
+ at::Tensor ms_deform_attn_cuda_forward(
15
+ const at::Tensor &value,
16
+ const at::Tensor &spatial_shapes,
17
+ const at::Tensor &level_start_index,
18
+ const at::Tensor &sampling_loc,
19
+ const at::Tensor &attn_weight,
20
+ const int im2col_step);
21
+
22
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
23
+ const at::Tensor &value,
24
+ const at::Tensor &spatial_shapes,
25
+ const at::Tensor &level_start_index,
26
+ const at::Tensor &sampling_loc,
27
+ const at::Tensor &attn_weight,
28
+ const at::Tensor &grad_output,
29
+ const int im2col_step);
30
+
detect_tools/upn/ops/src/cuda/ms_deform_im2col_cuda.cuh ADDED
@@ -0,0 +1,1327 @@
1
+ /*!
2
+ **************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************
7
+ * Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
8
+ * Copyright (c) 2018 Microsoft
9
+ **************************************************************************
10
+ */
11
+
12
+ #include <cstdio>
13
+ #include <algorithm>
14
+ #include <cstring>
15
+
16
+ #include <ATen/ATen.h>
17
+ #include <ATen/cuda/CUDAContext.h>
18
+
19
+ #include <THC/THCAtomics.cuh>
20
+
21
+ #define CUDA_KERNEL_LOOP(i, n) \
22
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
23
+ i < (n); \
24
+ i += blockDim.x * gridDim.x)
25
+
26
+ const int CUDA_NUM_THREADS = 1024;
27
+ inline int GET_BLOCKS(const int N, const int num_threads)
28
+ {
29
+ return (N + num_threads - 1) / num_threads;
30
+ }
31
+
32
+
33
+ template <typename scalar_t>
34
+ __device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
35
+ const int &height, const int &width, const int &nheads, const int &channels,
36
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
37
+ {
38
+ const int h_low = floor(h);
39
+ const int w_low = floor(w);
40
+ const int h_high = h_low + 1;
41
+ const int w_high = w_low + 1;
42
+
43
+ const scalar_t lh = h - h_low;
44
+ const scalar_t lw = w - w_low;
45
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
46
+
47
+ const int w_stride = nheads * channels;
48
+ const int h_stride = width * w_stride;
49
+ const int h_low_ptr_offset = h_low * h_stride;
50
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
51
+ const int w_low_ptr_offset = w_low * w_stride;
52
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
53
+ const int base_ptr = m * channels + c;
54
+
55
+ scalar_t v1 = 0;
56
+ if (h_low >= 0 && w_low >= 0)
57
+ {
58
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
59
+ v1 = bottom_data[ptr1];
60
+ }
61
+ scalar_t v2 = 0;
62
+ if (h_low >= 0 && w_high <= width - 1)
63
+ {
64
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
65
+ v2 = bottom_data[ptr2];
66
+ }
67
+ scalar_t v3 = 0;
68
+ if (h_high <= height - 1 && w_low >= 0)
69
+ {
70
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
71
+ v3 = bottom_data[ptr3];
72
+ }
73
+ scalar_t v4 = 0;
74
+ if (h_high <= height - 1 && w_high <= width - 1)
75
+ {
76
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
77
+ v4 = bottom_data[ptr4];
78
+ }
79
+
80
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
81
+
82
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
83
+ return val;
84
+ }
85
+
86
+
87
+ template <typename scalar_t>
88
+ __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
89
+ const int &height, const int &width, const int &nheads, const int &channels,
90
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
91
+ const scalar_t &top_grad,
92
+ const scalar_t &attn_weight,
93
+ scalar_t* &grad_value,
94
+ scalar_t* grad_sampling_loc,
95
+ scalar_t* grad_attn_weight)
96
+ {
97
+ const int h_low = floor(h);
98
+ const int w_low = floor(w);
99
+ const int h_high = h_low + 1;
100
+ const int w_high = w_low + 1;
101
+
102
+ const scalar_t lh = h - h_low;
103
+ const scalar_t lw = w - w_low;
104
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
105
+
106
+ const int w_stride = nheads * channels;
107
+ const int h_stride = width * w_stride;
108
+ const int h_low_ptr_offset = h_low * h_stride;
109
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
110
+ const int w_low_ptr_offset = w_low * w_stride;
111
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
112
+ const int base_ptr = m * channels + c;
113
+
114
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
115
+ const scalar_t top_grad_value = top_grad * attn_weight;
116
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
117
+
118
+ scalar_t v1 = 0;
119
+ if (h_low >= 0 && w_low >= 0)
120
+ {
121
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
122
+ v1 = bottom_data[ptr1];
123
+ grad_h_weight -= hw * v1;
124
+ grad_w_weight -= hh * v1;
125
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
126
+ }
127
+ scalar_t v2 = 0;
128
+ if (h_low >= 0 && w_high <= width - 1)
129
+ {
130
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
131
+ v2 = bottom_data[ptr2];
132
+ grad_h_weight -= lw * v2;
133
+ grad_w_weight += hh * v2;
134
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
135
+ }
136
+ scalar_t v3 = 0;
137
+ if (h_high <= height - 1 && w_low >= 0)
138
+ {
139
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
140
+ v3 = bottom_data[ptr3];
141
+ grad_h_weight += hw * v3;
142
+ grad_w_weight -= lh * v3;
143
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
144
+ }
145
+ scalar_t v4 = 0;
146
+ if (h_high <= height - 1 && w_high <= width - 1)
147
+ {
148
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
149
+ v4 = bottom_data[ptr4];
150
+ grad_h_weight += lw * v4;
151
+ grad_w_weight += lh * v4;
152
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
153
+ }
154
+
155
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
156
+ *grad_attn_weight = top_grad * val;
157
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
158
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
159
+ }
160
+
161
+
162
+ template <typename scalar_t>
163
+ __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
164
+ const int &height, const int &width, const int &nheads, const int &channels,
165
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
166
+ const scalar_t &top_grad,
167
+ const scalar_t &attn_weight,
168
+ scalar_t* &grad_value,
169
+ scalar_t* grad_sampling_loc,
170
+ scalar_t* grad_attn_weight)
171
+ {
172
+ const int h_low = floor(h);
173
+ const int w_low = floor(w);
174
+ const int h_high = h_low + 1;
175
+ const int w_high = w_low + 1;
176
+
177
+ const scalar_t lh = h - h_low;
178
+ const scalar_t lw = w - w_low;
179
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
180
+
181
+ const int w_stride = nheads * channels;
182
+ const int h_stride = width * w_stride;
183
+ const int h_low_ptr_offset = h_low * h_stride;
184
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
185
+ const int w_low_ptr_offset = w_low * w_stride;
186
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
187
+ const int base_ptr = m * channels + c;
188
+
189
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
190
+ const scalar_t top_grad_value = top_grad * attn_weight;
191
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
192
+
193
+ scalar_t v1 = 0;
194
+ if (h_low >= 0 && w_low >= 0)
195
+ {
196
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
197
+ v1 = bottom_data[ptr1];
198
+ grad_h_weight -= hw * v1;
199
+ grad_w_weight -= hh * v1;
200
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
201
+ }
202
+ scalar_t v2 = 0;
203
+ if (h_low >= 0 && w_high <= width - 1)
204
+ {
205
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
206
+ v2 = bottom_data[ptr2];
207
+ grad_h_weight -= lw * v2;
208
+ grad_w_weight += hh * v2;
209
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
210
+ }
211
+ scalar_t v3 = 0;
212
+ if (h_high <= height - 1 && w_low >= 0)
213
+ {
214
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
215
+ v3 = bottom_data[ptr3];
216
+ grad_h_weight += hw * v3;
217
+ grad_w_weight -= lh * v3;
218
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
219
+ }
220
+ scalar_t v4 = 0;
221
+ if (h_high <= height - 1 && w_high <= width - 1)
222
+ {
223
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
224
+ v4 = bottom_data[ptr4];
225
+ grad_h_weight += lw * v4;
226
+ grad_w_weight += lh * v4;
227
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
228
+ }
229
+
230
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
231
+ atomicAdd(grad_attn_weight, top_grad * val);
232
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
233
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
234
+ }
235
+
236
+
237
+ template <typename scalar_t>
238
+ __global__ void ms_deformable_im2col_gpu_kernel(const int n,
239
+ const scalar_t *data_value,
240
+ const int64_t *data_spatial_shapes,
241
+ const int64_t *data_level_start_index,
242
+ const scalar_t *data_sampling_loc,
243
+ const scalar_t *data_attn_weight,
244
+ const int batch_size,
245
+ const int spatial_size,
246
+ const int num_heads,
247
+ const int channels,
248
+ const int num_levels,
249
+ const int num_query,
250
+ const int num_point,
251
+ scalar_t *data_col)
252
+ {
253
+ CUDA_KERNEL_LOOP(index, n)
254
+ {
255
+ int _temp = index;
256
+ const int c_col = _temp % channels;
257
+ _temp /= channels;
258
+ const int sampling_index = _temp;
259
+ const int m_col = _temp % num_heads;
260
+ _temp /= num_heads;
261
+ const int q_col = _temp % num_query;
262
+ _temp /= num_query;
263
+ const int b_col = _temp;
264
+
265
+ scalar_t *data_col_ptr = data_col + index;
266
+ int data_weight_ptr = sampling_index * num_levels * num_point;
267
+ int data_loc_w_ptr = data_weight_ptr << 1;
268
+ const int qid_stride = num_heads * channels;
269
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
270
+ scalar_t col = 0;
271
+
272
+ for (int l_col=0; l_col < num_levels; ++l_col)
273
+ {
274
+ const int level_start_id = data_level_start_index[l_col];
275
+ const int spatial_h_ptr = l_col << 1;
276
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
277
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
278
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
279
+ for (int p_col=0; p_col < num_point; ++p_col)
280
+ {
281
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
282
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
283
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
284
+
285
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
286
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
287
+
288
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
289
+ {
290
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
291
+ }
292
+
293
+ data_weight_ptr += 1;
294
+ data_loc_w_ptr += 2;
295
+ }
296
+ }
297
+ *data_col_ptr = col;
298
+ }
299
+ }
300
+
301
+ template <typename scalar_t, unsigned int blockSize>
302
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
303
+ const scalar_t *grad_col,
304
+ const scalar_t *data_value,
305
+ const int64_t *data_spatial_shapes,
306
+ const int64_t *data_level_start_index,
307
+ const scalar_t *data_sampling_loc,
308
+ const scalar_t *data_attn_weight,
309
+ const int batch_size,
310
+ const int spatial_size,
311
+ const int num_heads,
312
+ const int channels,
313
+ const int num_levels,
314
+ const int num_query,
315
+ const int num_point,
316
+ scalar_t *grad_value,
317
+ scalar_t *grad_sampling_loc,
318
+ scalar_t *grad_attn_weight)
319
+ {
320
+ CUDA_KERNEL_LOOP(index, n)
321
+ {
322
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
323
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
324
+ unsigned int tid = threadIdx.x;
325
+ int _temp = index;
326
+ const int c_col = _temp % channels;
327
+ _temp /= channels;
328
+ const int sampling_index = _temp;
329
+ const int m_col = _temp % num_heads;
330
+ _temp /= num_heads;
331
+ const int q_col = _temp % num_query;
332
+ _temp /= num_query;
333
+ const int b_col = _temp;
334
+
335
+ const scalar_t top_grad = grad_col[index];
336
+
337
+ int data_weight_ptr = sampling_index * num_levels * num_point;
338
+ int data_loc_w_ptr = data_weight_ptr << 1;
339
+ const int grad_sampling_ptr = data_weight_ptr;
340
+ grad_sampling_loc += grad_sampling_ptr << 1;
341
+ grad_attn_weight += grad_sampling_ptr;
342
+ const int grad_weight_stride = 1;
343
+ const int grad_loc_stride = 2;
344
+ const int qid_stride = num_heads * channels;
345
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
346
+
347
+ for (int l_col=0; l_col < num_levels; ++l_col)
348
+ {
349
+ const int level_start_id = data_level_start_index[l_col];
350
+ const int spatial_h_ptr = l_col << 1;
351
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
352
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
353
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
354
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
355
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
356
+
357
+ for (int p_col=0; p_col < num_point; ++p_col)
358
+ {
359
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
360
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
361
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
362
+
363
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
364
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
365
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
366
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
367
+ *(cache_grad_attn_weight+threadIdx.x)=0;
368
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
369
+ {
370
+ ms_deform_attn_col2im_bilinear(
371
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
372
+ top_grad, weight, grad_value_ptr,
373
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
374
+ }
375
+
376
+ __syncthreads();
377
+ if (tid == 0)
378
+ {
379
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
380
+ int sid=2;
381
+ for (unsigned int tid = 1; tid < blockSize; ++tid)
382
+ {
383
+ _grad_w += cache_grad_sampling_loc[sid];
384
+ _grad_h += cache_grad_sampling_loc[sid + 1];
385
+ _grad_a += cache_grad_attn_weight[tid];
386
+ sid += 2;
387
+ }
388
+
389
+
390
+ *grad_sampling_loc = _grad_w;
391
+ *(grad_sampling_loc + 1) = _grad_h;
392
+ *grad_attn_weight = _grad_a;
393
+ }
394
+ __syncthreads();
395
+
396
+ data_weight_ptr += 1;
397
+ data_loc_w_ptr += 2;
398
+ grad_attn_weight += grad_weight_stride;
399
+ grad_sampling_loc += grad_loc_stride;
400
+ }
401
+ }
402
+ }
403
+ }
404
+
405
+
406
+ template <typename scalar_t, unsigned int blockSize>
407
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
408
+ const scalar_t *grad_col,
409
+ const scalar_t *data_value,
410
+ const int64_t *data_spatial_shapes,
411
+ const int64_t *data_level_start_index,
412
+ const scalar_t *data_sampling_loc,
413
+ const scalar_t *data_attn_weight,
414
+ const int batch_size,
415
+ const int spatial_size,
416
+ const int num_heads,
417
+ const int channels,
418
+ const int num_levels,
419
+ const int num_query,
420
+ const int num_point,
421
+ scalar_t *grad_value,
422
+ scalar_t *grad_sampling_loc,
423
+ scalar_t *grad_attn_weight)
424
+ {
425
+ CUDA_KERNEL_LOOP(index, n)
426
+ {
427
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
428
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
429
+ unsigned int tid = threadIdx.x;
430
+ int _temp = index;
431
+ const int c_col = _temp % channels;
432
+ _temp /= channels;
433
+ const int sampling_index = _temp;
434
+ const int m_col = _temp % num_heads;
435
+ _temp /= num_heads;
436
+ const int q_col = _temp % num_query;
437
+ _temp /= num_query;
438
+ const int b_col = _temp;
439
+
440
+ const scalar_t top_grad = grad_col[index];
441
+
442
+ int data_weight_ptr = sampling_index * num_levels * num_point;
443
+ int data_loc_w_ptr = data_weight_ptr << 1;
444
+ const int grad_sampling_ptr = data_weight_ptr;
445
+ grad_sampling_loc += grad_sampling_ptr << 1;
446
+ grad_attn_weight += grad_sampling_ptr;
447
+ const int grad_weight_stride = 1;
448
+ const int grad_loc_stride = 2;
449
+ const int qid_stride = num_heads * channels;
450
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
451
+
452
+ for (int l_col=0; l_col < num_levels; ++l_col)
453
+ {
454
+ const int level_start_id = data_level_start_index[l_col];
455
+ const int spatial_h_ptr = l_col << 1;
456
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
457
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
458
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
459
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
460
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
461
+
462
+ for (int p_col=0; p_col < num_point; ++p_col)
463
+ {
464
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
465
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
466
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
467
+
468
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
469
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
470
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
471
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
472
+ *(cache_grad_attn_weight+threadIdx.x)=0;
473
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
474
+ {
475
+ ms_deform_attn_col2im_bilinear(
476
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
477
+ top_grad, weight, grad_value_ptr,
478
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
479
+ }
480
+
481
+ __syncthreads();
482
+
483
+ for (unsigned int s=blockSize/2; s>0; s>>=1)
484
+ {
485
+ if (tid < s) {
486
+ const unsigned int xid1 = tid << 1;
487
+ const unsigned int xid2 = (tid + s) << 1;
488
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
489
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
490
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
491
+ }
492
+ __syncthreads();
493
+ }
494
+
495
+ if (tid == 0)
496
+ {
497
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
498
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
499
+ *grad_attn_weight = cache_grad_attn_weight[0];
500
+ }
501
+ __syncthreads();
502
+
503
+ data_weight_ptr += 1;
504
+ data_loc_w_ptr += 2;
505
+ grad_attn_weight += grad_weight_stride;
506
+ grad_sampling_loc += grad_loc_stride;
507
+ }
508
+ }
509
+ }
510
+ }
511
+
512
+
513
+ template <typename scalar_t>
514
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
515
+ const scalar_t *grad_col,
516
+ const scalar_t *data_value,
517
+ const int64_t *data_spatial_shapes,
518
+ const int64_t *data_level_start_index,
519
+ const scalar_t *data_sampling_loc,
520
+ const scalar_t *data_attn_weight,
521
+ const int batch_size,
522
+ const int spatial_size,
523
+ const int num_heads,
524
+ const int channels,
525
+ const int num_levels,
526
+ const int num_query,
527
+ const int num_point,
528
+ scalar_t *grad_value,
529
+ scalar_t *grad_sampling_loc,
530
+ scalar_t *grad_attn_weight)
531
+ {
532
+ CUDA_KERNEL_LOOP(index, n)
533
+ {
534
+ extern __shared__ int _s[];
535
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
536
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
537
+ unsigned int tid = threadIdx.x;
538
+ int _temp = index;
539
+ const int c_col = _temp % channels;
540
+ _temp /= channels;
541
+ const int sampling_index = _temp;
542
+ const int m_col = _temp % num_heads;
543
+ _temp /= num_heads;
544
+ const int q_col = _temp % num_query;
545
+ _temp /= num_query;
546
+ const int b_col = _temp;
547
+
548
+ const scalar_t top_grad = grad_col[index];
549
+
550
+ int data_weight_ptr = sampling_index * num_levels * num_point;
551
+ int data_loc_w_ptr = data_weight_ptr << 1;
552
+ const int grad_sampling_ptr = data_weight_ptr;
553
+ grad_sampling_loc += grad_sampling_ptr << 1;
554
+ grad_attn_weight += grad_sampling_ptr;
555
+ const int grad_weight_stride = 1;
556
+ const int grad_loc_stride = 2;
557
+ const int qid_stride = num_heads * channels;
558
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
559
+
560
+ for (int l_col=0; l_col < num_levels; ++l_col)
561
+ {
562
+ const int level_start_id = data_level_start_index[l_col];
563
+ const int spatial_h_ptr = l_col << 1;
564
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
565
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
566
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
567
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
568
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
569
+
570
+ for (int p_col=0; p_col < num_point; ++p_col)
571
+ {
572
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
573
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
574
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
575
+
576
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
577
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
578
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
579
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
580
+ *(cache_grad_attn_weight+threadIdx.x)=0;
581
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
582
+ {
583
+ ms_deform_attn_col2im_bilinear(
584
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
585
+ top_grad, weight, grad_value_ptr,
586
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
587
+ }
588
+
589
+ __syncthreads();
590
+ if (tid == 0)
591
+ {
592
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
593
+ int sid=2;
594
+ for (unsigned int tid = 1; tid < blockDim.x; ++tid)
595
+ {
596
+ _grad_w += cache_grad_sampling_loc[sid];
597
+ _grad_h += cache_grad_sampling_loc[sid + 1];
598
+ _grad_a += cache_grad_attn_weight[tid];
599
+ sid += 2;
600
+ }
601
+
602
+
603
+ *grad_sampling_loc = _grad_w;
604
+ *(grad_sampling_loc + 1) = _grad_h;
605
+ *grad_attn_weight = _grad_a;
606
+ }
607
+ __syncthreads();
608
+
609
+ data_weight_ptr += 1;
610
+ data_loc_w_ptr += 2;
611
+ grad_attn_weight += grad_weight_stride;
612
+ grad_sampling_loc += grad_loc_stride;
613
+ }
614
+ }
615
+ }
616
+ }
617
+
618
+ template <typename scalar_t>
619
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
620
+ const scalar_t *grad_col,
621
+ const scalar_t *data_value,
622
+ const int64_t *data_spatial_shapes,
623
+ const int64_t *data_level_start_index,
624
+ const scalar_t *data_sampling_loc,
625
+ const scalar_t *data_attn_weight,
626
+ const int batch_size,
627
+ const int spatial_size,
628
+ const int num_heads,
629
+ const int channels,
630
+ const int num_levels,
631
+ const int num_query,
632
+ const int num_point,
633
+ scalar_t *grad_value,
634
+ scalar_t *grad_sampling_loc,
635
+ scalar_t *grad_attn_weight)
636
+ {
637
+ CUDA_KERNEL_LOOP(index, n)
638
+ {
639
+ extern __shared__ int _s[];
640
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
641
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
642
+ unsigned int tid = threadIdx.x;
643
+ int _temp = index;
644
+ const int c_col = _temp % channels;
645
+ _temp /= channels;
646
+ const int sampling_index = _temp;
647
+ const int m_col = _temp % num_heads;
648
+ _temp /= num_heads;
649
+ const int q_col = _temp % num_query;
650
+ _temp /= num_query;
651
+ const int b_col = _temp;
652
+
653
+ const scalar_t top_grad = grad_col[index];
654
+
655
+ int data_weight_ptr = sampling_index * num_levels * num_point;
656
+ int data_loc_w_ptr = data_weight_ptr << 1;
657
+ const int grad_sampling_ptr = data_weight_ptr;
658
+ grad_sampling_loc += grad_sampling_ptr << 1;
659
+ grad_attn_weight += grad_sampling_ptr;
660
+ const int grad_weight_stride = 1;
661
+ const int grad_loc_stride = 2;
662
+ const int qid_stride = num_heads * channels;
663
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
664
+
665
+ for (int l_col=0; l_col < num_levels; ++l_col)
666
+ {
667
+ const int level_start_id = data_level_start_index[l_col];
668
+ const int spatial_h_ptr = l_col << 1;
669
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
670
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
671
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
672
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
673
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
674
+
675
+ for (int p_col=0; p_col < num_point; ++p_col)
676
+ {
677
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
678
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
679
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
680
+
681
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
682
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
683
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
684
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
685
+ *(cache_grad_attn_weight+threadIdx.x)=0;
686
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
687
+ {
688
+ ms_deform_attn_col2im_bilinear(
689
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
690
+ top_grad, weight, grad_value_ptr,
691
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
692
+ }
693
+
694
+ __syncthreads();
695
+
696
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
697
+ {
698
+ if (tid < s) {
699
+ const unsigned int xid1 = tid << 1;
700
+ const unsigned int xid2 = (tid + s) << 1;
701
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
702
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
703
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
704
+ if (tid + (s << 1) < spre)
705
+ {
706
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
707
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
708
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
709
+ }
710
+ }
711
+ __syncthreads();
712
+ }
713
+
714
+ if (tid == 0)
715
+ {
716
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
717
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
718
+ *grad_attn_weight = cache_grad_attn_weight[0];
719
+ }
720
+ __syncthreads();
721
+
722
+ data_weight_ptr += 1;
723
+ data_loc_w_ptr += 2;
724
+ grad_attn_weight += grad_weight_stride;
725
+ grad_sampling_loc += grad_loc_stride;
726
+ }
727
+ }
728
+ }
729
+ }
730
+
731
+ template <typename scalar_t>
732
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
733
+ const scalar_t *grad_col,
734
+ const scalar_t *data_value,
735
+ const int64_t *data_spatial_shapes,
736
+ const int64_t *data_level_start_index,
737
+ const scalar_t *data_sampling_loc,
738
+ const scalar_t *data_attn_weight,
739
+ const int batch_size,
740
+ const int spatial_size,
741
+ const int num_heads,
742
+ const int channels,
743
+ const int num_levels,
744
+ const int num_query,
745
+ const int num_point,
746
+ scalar_t *grad_value,
747
+ scalar_t *grad_sampling_loc,
748
+ scalar_t *grad_attn_weight)
749
+ {
750
+ CUDA_KERNEL_LOOP(index, n)
751
+ {
752
+ extern __shared__ int _s[];
753
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
754
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
755
+ unsigned int tid = threadIdx.x;
756
+ int _temp = index;
757
+ const int c_col = _temp % channels;
758
+ _temp /= channels;
759
+ const int sampling_index = _temp;
760
+ const int m_col = _temp % num_heads;
761
+ _temp /= num_heads;
762
+ const int q_col = _temp % num_query;
763
+ _temp /= num_query;
764
+ const int b_col = _temp;
765
+
766
+ const scalar_t top_grad = grad_col[index];
767
+
768
+ int data_weight_ptr = sampling_index * num_levels * num_point;
769
+ int data_loc_w_ptr = data_weight_ptr << 1;
770
+ const int grad_sampling_ptr = data_weight_ptr;
771
+ grad_sampling_loc += grad_sampling_ptr << 1;
772
+ grad_attn_weight += grad_sampling_ptr;
773
+ const int grad_weight_stride = 1;
774
+ const int grad_loc_stride = 2;
775
+ const int qid_stride = num_heads * channels;
776
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
777
+
778
+ for (int l_col=0; l_col < num_levels; ++l_col)
779
+ {
780
+ const int level_start_id = data_level_start_index[l_col];
781
+ const int spatial_h_ptr = l_col << 1;
782
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
783
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
784
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
785
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
786
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
787
+
788
+ for (int p_col=0; p_col < num_point; ++p_col)
789
+ {
790
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
791
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
792
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
793
+
794
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
795
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
796
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
797
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
798
+ *(cache_grad_attn_weight+threadIdx.x)=0;
799
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
800
+ {
801
+ ms_deform_attn_col2im_bilinear(
802
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
803
+ top_grad, weight, grad_value_ptr,
804
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
805
+ }
806
+
807
+ __syncthreads();
808
+
809
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
810
+ {
811
+ if (tid < s) {
812
+ const unsigned int xid1 = tid << 1;
813
+ const unsigned int xid2 = (tid + s) << 1;
814
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
815
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
816
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
817
+ if (tid + (s << 1) < spre)
818
+ {
819
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
820
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
821
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
822
+ }
823
+ }
824
+ __syncthreads();
825
+ }
826
+
827
+ if (tid == 0)
828
+ {
829
+ atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
830
+ atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
831
+ atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
832
+ }
833
+ __syncthreads();
834
+
835
+ data_weight_ptr += 1;
836
+ data_loc_w_ptr += 2;
837
+ grad_attn_weight += grad_weight_stride;
838
+ grad_sampling_loc += grad_loc_stride;
839
+ }
840
+ }
841
+ }
842
+ }
843
+
844
+
845
+ template <typename scalar_t>
846
+ __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
847
+ const scalar_t *grad_col,
848
+ const scalar_t *data_value,
849
+ const int64_t *data_spatial_shapes,
850
+ const int64_t *data_level_start_index,
851
+ const scalar_t *data_sampling_loc,
852
+ const scalar_t *data_attn_weight,
853
+ const int batch_size,
854
+ const int spatial_size,
855
+ const int num_heads,
856
+ const int channels,
857
+ const int num_levels,
858
+ const int num_query,
859
+ const int num_point,
860
+ scalar_t *grad_value,
861
+ scalar_t *grad_sampling_loc,
862
+ scalar_t *grad_attn_weight)
863
+ {
864
+ CUDA_KERNEL_LOOP(index, n)
865
+ {
866
+ int _temp = index;
867
+ const int c_col = _temp % channels;
868
+ _temp /= channels;
869
+ const int sampling_index = _temp;
870
+ const int m_col = _temp % num_heads;
871
+ _temp /= num_heads;
872
+ const int q_col = _temp % num_query;
873
+ _temp /= num_query;
874
+ const int b_col = _temp;
875
+
876
+ const scalar_t top_grad = grad_col[index];
877
+
878
+ int data_weight_ptr = sampling_index * num_levels * num_point;
879
+ int data_loc_w_ptr = data_weight_ptr << 1;
880
+ const int grad_sampling_ptr = data_weight_ptr;
881
+ grad_sampling_loc += grad_sampling_ptr << 1;
882
+ grad_attn_weight += grad_sampling_ptr;
883
+ const int grad_weight_stride = 1;
884
+ const int grad_loc_stride = 2;
885
+ const int qid_stride = num_heads * channels;
886
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
887
+
888
+ for (int l_col=0; l_col < num_levels; ++l_col)
889
+ {
890
+ const int level_start_id = data_level_start_index[l_col];
891
+ const int spatial_h_ptr = l_col << 1;
892
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
893
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
894
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
895
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
896
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
897
+
898
+ for (int p_col=0; p_col < num_point; ++p_col)
899
+ {
900
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
901
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
902
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
903
+
904
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
905
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
906
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
907
+ {
908
+ ms_deform_attn_col2im_bilinear_gm(
909
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
910
+ top_grad, weight, grad_value_ptr,
911
+ grad_sampling_loc, grad_attn_weight);
912
+ }
913
+ data_weight_ptr += 1;
914
+ data_loc_w_ptr += 2;
915
+ grad_attn_weight += grad_weight_stride;
916
+ grad_sampling_loc += grad_loc_stride;
917
+ }
918
+ }
919
+ }
920
+ }
921
+
922
+
923
+ template <typename scalar_t>
924
+ void ms_deformable_im2col_cuda(cudaStream_t stream,
925
+ const scalar_t* data_value,
926
+ const int64_t* data_spatial_shapes,
927
+ const int64_t* data_level_start_index,
928
+ const scalar_t* data_sampling_loc,
929
+ const scalar_t* data_attn_weight,
930
+ const int batch_size,
931
+ const int spatial_size,
932
+ const int num_heads,
933
+ const int channels,
934
+ const int num_levels,
935
+ const int num_query,
936
+ const int num_point,
937
+ scalar_t* data_col)
938
+ {
939
+ const int num_kernels = batch_size * num_query * num_heads * channels;
940
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
941
+ const int num_threads = CUDA_NUM_THREADS;
942
+ ms_deformable_im2col_gpu_kernel<scalar_t>
943
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
944
+ 0, stream>>>(
945
+ num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
946
+ batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
947
+
948
+ cudaError_t err = cudaGetLastError();
949
+ if (err != cudaSuccess)
950
+ {
951
+ printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
952
+ }
953
+
954
+ }
955
+
956
+ template <typename scalar_t>
957
+ void ms_deformable_col2im_cuda(cudaStream_t stream,
958
+ const scalar_t* grad_col,
959
+ const scalar_t* data_value,
960
+ const int64_t * data_spatial_shapes,
961
+ const int64_t * data_level_start_index,
962
+ const scalar_t * data_sampling_loc,
963
+ const scalar_t * data_attn_weight,
964
+ const int batch_size,
965
+ const int spatial_size,
966
+ const int num_heads,
967
+ const int channels,
968
+ const int num_levels,
969
+ const int num_query,
970
+ const int num_point,
971
+ scalar_t* grad_value,
972
+ scalar_t* grad_sampling_loc,
973
+ scalar_t* grad_attn_weight)
974
+ {
975
+ const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
976
+ const int num_kernels = batch_size * num_query * num_heads * channels;
977
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
978
+ if (channels > 1024)
979
+ {
980
+ if ((channels & 1023) == 0)
981
+ {
982
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
983
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
984
+ num_threads*3*sizeof(scalar_t), stream>>>(
985
+ num_kernels,
986
+ grad_col,
987
+ data_value,
988
+ data_spatial_shapes,
989
+ data_level_start_index,
990
+ data_sampling_loc,
991
+ data_attn_weight,
992
+ batch_size,
993
+ spatial_size,
994
+ num_heads,
995
+ channels,
996
+ num_levels,
997
+ num_query,
998
+ num_point,
999
+ grad_value,
1000
+ grad_sampling_loc,
1001
+ grad_attn_weight);
1002
+ }
1003
+ else
1004
+ {
1005
+ ms_deformable_col2im_gpu_kernel_gm<scalar_t>
1006
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1007
+ 0, stream>>>(
1008
+ num_kernels,
1009
+ grad_col,
1010
+ data_value,
1011
+ data_spatial_shapes,
1012
+ data_level_start_index,
1013
+ data_sampling_loc,
1014
+ data_attn_weight,
1015
+ batch_size,
1016
+ spatial_size,
1017
+ num_heads,
1018
+ channels,
1019
+ num_levels,
1020
+ num_query,
1021
+ num_point,
1022
+ grad_value,
1023
+ grad_sampling_loc,
1024
+ grad_attn_weight);
1025
+ }
1026
+ }
1027
+ else{
1028
+ switch(channels)
1029
+ {
1030
+ case 1:
1031
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
1032
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1033
+ 0, stream>>>(
1034
+ num_kernels,
1035
+ grad_col,
1036
+ data_value,
1037
+ data_spatial_shapes,
1038
+ data_level_start_index,
1039
+ data_sampling_loc,
1040
+ data_attn_weight,
1041
+ batch_size,
1042
+ spatial_size,
1043
+ num_heads,
1044
+ channels,
1045
+ num_levels,
1046
+ num_query,
1047
+ num_point,
1048
+ grad_value,
1049
+ grad_sampling_loc,
1050
+ grad_attn_weight);
1051
+ break;
1052
+ case 2:
1053
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
1054
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1055
+ 0, stream>>>(
1056
+ num_kernels,
1057
+ grad_col,
1058
+ data_value,
1059
+ data_spatial_shapes,
1060
+ data_level_start_index,
1061
+ data_sampling_loc,
1062
+ data_attn_weight,
1063
+ batch_size,
1064
+ spatial_size,
1065
+ num_heads,
1066
+ channels,
1067
+ num_levels,
1068
+ num_query,
1069
+ num_point,
1070
+ grad_value,
1071
+ grad_sampling_loc,
1072
+ grad_attn_weight);
1073
+ break;
1074
+ case 4:
1075
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
1076
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1077
+ 0, stream>>>(
1078
+ num_kernels,
1079
+ grad_col,
1080
+ data_value,
1081
+ data_spatial_shapes,
1082
+ data_level_start_index,
1083
+ data_sampling_loc,
1084
+ data_attn_weight,
1085
+ batch_size,
1086
+ spatial_size,
1087
+ num_heads,
1088
+ channels,
1089
+ num_levels,
1090
+ num_query,
1091
+ num_point,
1092
+ grad_value,
1093
+ grad_sampling_loc,
1094
+ grad_attn_weight);
1095
+ break;
1096
+ case 8:
1097
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
1098
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1099
+ 0, stream>>>(
1100
+ num_kernels,
1101
+ grad_col,
1102
+ data_value,
1103
+ data_spatial_shapes,
1104
+ data_level_start_index,
1105
+ data_sampling_loc,
1106
+ data_attn_weight,
1107
+ batch_size,
1108
+ spatial_size,
1109
+ num_heads,
1110
+ channels,
1111
+ num_levels,
1112
+ num_query,
1113
+ num_point,
1114
+ grad_value,
1115
+ grad_sampling_loc,
1116
+ grad_attn_weight);
1117
+ break;
1118
+ case 16:
1119
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
1120
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1121
+ 0, stream>>>(
1122
+ num_kernels,
1123
+ grad_col,
1124
+ data_value,
1125
+ data_spatial_shapes,
1126
+ data_level_start_index,
1127
+ data_sampling_loc,
1128
+ data_attn_weight,
1129
+ batch_size,
1130
+ spatial_size,
1131
+ num_heads,
1132
+ channels,
1133
+ num_levels,
1134
+ num_query,
1135
+ num_point,
1136
+ grad_value,
1137
+ grad_sampling_loc,
1138
+ grad_attn_weight);
1139
+ break;
1140
+ case 32:
1141
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
1142
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1143
+ 0, stream>>>(
1144
+ num_kernels,
1145
+ grad_col,
1146
+ data_value,
1147
+ data_spatial_shapes,
1148
+ data_level_start_index,
1149
+ data_sampling_loc,
1150
+ data_attn_weight,
1151
+ batch_size,
1152
+ spatial_size,
1153
+ num_heads,
1154
+ channels,
1155
+ num_levels,
1156
+ num_query,
1157
+ num_point,
1158
+ grad_value,
1159
+ grad_sampling_loc,
1160
+ grad_attn_weight);
1161
+ break;
1162
+ case 64:
1163
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
1164
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1165
+ 0, stream>>>(
1166
+ num_kernels,
1167
+ grad_col,
1168
+ data_value,
1169
+ data_spatial_shapes,
1170
+ data_level_start_index,
1171
+ data_sampling_loc,
1172
+ data_attn_weight,
1173
+ batch_size,
1174
+ spatial_size,
1175
+ num_heads,
1176
+ channels,
1177
+ num_levels,
1178
+ num_query,
1179
+ num_point,
1180
+ grad_value,
1181
+ grad_sampling_loc,
1182
+ grad_attn_weight);
1183
+ break;
1184
+ case 128:
1185
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
1186
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1187
+ 0, stream>>>(
1188
+ num_kernels,
1189
+ grad_col,
1190
+ data_value,
1191
+ data_spatial_shapes,
1192
+ data_level_start_index,
1193
+ data_sampling_loc,
1194
+ data_attn_weight,
1195
+ batch_size,
1196
+ spatial_size,
1197
+ num_heads,
1198
+ channels,
1199
+ num_levels,
1200
+ num_query,
1201
+ num_point,
1202
+ grad_value,
1203
+ grad_sampling_loc,
1204
+ grad_attn_weight);
1205
+ break;
1206
+ case 256:
1207
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
1208
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1209
+ 0, stream>>>(
1210
+ num_kernels,
1211
+ grad_col,
1212
+ data_value,
1213
+ data_spatial_shapes,
1214
+ data_level_start_index,
1215
+ data_sampling_loc,
1216
+ data_attn_weight,
1217
+ batch_size,
1218
+ spatial_size,
1219
+ num_heads,
1220
+ channels,
1221
+ num_levels,
1222
+ num_query,
1223
+ num_point,
1224
+ grad_value,
1225
+ grad_sampling_loc,
1226
+ grad_attn_weight);
1227
+ break;
1228
+ case 512:
1229
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
1230
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1231
+ 0, stream>>>(
1232
+ num_kernels,
1233
+ grad_col,
1234
+ data_value,
1235
+ data_spatial_shapes,
1236
+ data_level_start_index,
1237
+ data_sampling_loc,
1238
+ data_attn_weight,
1239
+ batch_size,
1240
+ spatial_size,
1241
+ num_heads,
1242
+ channels,
1243
+ num_levels,
1244
+ num_query,
1245
+ num_point,
1246
+ grad_value,
1247
+ grad_sampling_loc,
1248
+ grad_attn_weight);
1249
+ break;
1250
+ case 1024:
1251
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
1252
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1253
+ 0, stream>>>(
1254
+ num_kernels,
1255
+ grad_col,
1256
+ data_value,
1257
+ data_spatial_shapes,
1258
+ data_level_start_index,
1259
+ data_sampling_loc,
1260
+ data_attn_weight,
1261
+ batch_size,
1262
+ spatial_size,
1263
+ num_heads,
1264
+ channels,
1265
+ num_levels,
1266
+ num_query,
1267
+ num_point,
1268
+ grad_value,
1269
+ grad_sampling_loc,
1270
+ grad_attn_weight);
1271
+ break;
1272
+ default:
1273
+ if (channels < 64)
1274
+ {
1275
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
1276
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1277
+ num_threads*3*sizeof(scalar_t), stream>>>(
1278
+ num_kernels,
1279
+ grad_col,
1280
+ data_value,
1281
+ data_spatial_shapes,
1282
+ data_level_start_index,
1283
+ data_sampling_loc,
1284
+ data_attn_weight,
1285
+ batch_size,
1286
+ spatial_size,
1287
+ num_heads,
1288
+ channels,
1289
+ num_levels,
1290
+ num_query,
1291
+ num_point,
1292
+ grad_value,
1293
+ grad_sampling_loc,
1294
+ grad_attn_weight);
1295
+ }
1296
+ else
1297
+ {
1298
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
1299
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1300
+ num_threads*3*sizeof(scalar_t), stream>>>(
1301
+ num_kernels,
1302
+ grad_col,
1303
+ data_value,
1304
+ data_spatial_shapes,
1305
+ data_level_start_index,
1306
+ data_sampling_loc,
1307
+ data_attn_weight,
1308
+ batch_size,
1309
+ spatial_size,
1310
+ num_heads,
1311
+ channels,
1312
+ num_levels,
1313
+ num_query,
1314
+ num_point,
1315
+ grad_value,
1316
+ grad_sampling_loc,
1317
+ grad_attn_weight);
1318
+ }
1319
+ }
1320
+ }
1321
+ cudaError_t err = cudaGetLastError();
1322
+ if (err != cudaSuccess)
1323
+ {
1324
+ printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
1325
+ }
1326
+
1327
+ }
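
Note: the im2col/col2im kernels above implement the sampling and weighting step of multi-scale deformable attention. For clarity, here is a hedged pure-PyTorch sketch of the same forward computation; the repository ships its own reference as ms_deform_attn_core_pytorch in detect_tools/upn/ops/functions/ms_deform_attn_func.py (which test.py below compares against), and this sketch is not a copy of that file:

# Hedged reference sketch of what ms_deformable_im2col_cuda computes.
import torch
import torch.nn.functional as F


def ms_deform_attn_core_pytorch_ref(value, value_spatial_shapes,
                                    sampling_locations, attention_weights):
    # value:                (N, S, M, D)   flattened multi-level feature maps
    # value_spatial_shapes: (L, 2)         (H, W) per level
    # sampling_locations:   (N, Lq, M, L, P, 2) normalized to [0, 1]
    # attention_weights:    (N, Lq, M, L, P)
    N_, S_, M_, D_ = value.shape
    _, Lq_, M_, L_, P_, _ = sampling_locations.shape
    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
    sampling_grids = 2 * sampling_locations - 1  # grid_sample expects [-1, 1]
    sampling_value_list = []
    for lid_, (H_, W_) in enumerate(value_spatial_shapes):
        # (N, H*W, M, D) -> (N*M, D, H, W)
        value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_)
        # (N, Lq, M, P, 2) -> (N*M, Lq, P, 2)
        sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
        # bilinear sampling with zero padding, matching the CUDA kernel
        sampling_value_list.append(F.grid_sample(
            value_l_, sampling_grid_l_, mode='bilinear',
            padding_mode='zeros', align_corners=False))
    # (N*M, 1, Lq, L*P)
    attention_weights = attention_weights.transpose(1, 2).reshape(N_ * M_, 1, Lq_, L_ * P_)
    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2)
              * attention_weights).sum(-1).view(N_, M_ * D_, Lq_)
    return output.transpose(1, 2).contiguous()  # (N, Lq, M*D)
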
detect_tools/upn/ops/src/ms_deform_attn.h ADDED
@@ -0,0 +1,62 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+
13
+ #include "cpu/ms_deform_attn_cpu.h"
14
+
15
+ #ifdef WITH_CUDA
16
+ #include "cuda/ms_deform_attn_cuda.h"
17
+ #endif
18
+
19
+
20
+ at::Tensor
21
+ ms_deform_attn_forward(
22
+ const at::Tensor &value,
23
+ const at::Tensor &spatial_shapes,
24
+ const at::Tensor &level_start_index,
25
+ const at::Tensor &sampling_loc,
26
+ const at::Tensor &attn_weight,
27
+ const int im2col_step)
28
+ {
29
+ if (value.is_cuda())
30
+ {
31
+ #ifdef WITH_CUDA
32
+ return ms_deform_attn_cuda_forward(
33
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
34
+ #else
35
+ AT_ERROR("Not compiled with GPU support");
36
+ #endif
37
+ }
38
+ AT_ERROR("Not implemented on the CPU");
39
+ }
40
+
41
+ std::vector<at::Tensor>
42
+ ms_deform_attn_backward(
43
+ const at::Tensor &value,
44
+ const at::Tensor &spatial_shapes,
45
+ const at::Tensor &level_start_index,
46
+ const at::Tensor &sampling_loc,
47
+ const at::Tensor &attn_weight,
48
+ const at::Tensor &grad_output,
49
+ const int im2col_step)
50
+ {
51
+ if (value.is_cuda())
52
+ {
53
+ #ifdef WITH_CUDA
54
+ return ms_deform_attn_cuda_backward(
55
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
56
+ #else
57
+ AT_ERROR("Not compiled with GPU support");
58
+ #endif
59
+ }
60
+ AT_ERROR("Not implemented on the CPU");
61
+ }
62
+
detect_tools/upn/ops/src/vision.cpp ADDED
@@ -0,0 +1,16 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #include "ms_deform_attn.h"
12
+
13
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
14
+ m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
15
+ m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
16
+ }
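
Note: a hedged smoke test for the pybind11 module registered above, assuming the extension has already been built via detect_tools/upn/ops/setup.py and that setup.py names it MultiScaleDeformableAttention (the module name is an assumption; check setup.py for the authoritative name and sources):

# Hedged smoke test. Build first, e.g.:
#   cd detect_tools/upn/ops && python setup.py build install
import torch  # import torch first so the extension can resolve libtorch symbols

import MultiScaleDeformableAttention as MSDA  # name assumed, see setup.py

# These are the two symbols bound in vision.cpp above.
assert hasattr(MSDA, "ms_deform_attn_forward")
assert hasattr(MSDA, "ms_deform_attn_backward")
print("ms_deform_attn extension loaded from:", MSDA.__file__)
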
detect_tools/upn/ops/test.py ADDED
@@ -0,0 +1,89 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ from __future__ import absolute_import
10
+ from __future__ import print_function
11
+ from __future__ import division
12
+
13
+ import time
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.autograd import gradcheck
17
+
18
+ from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
19
+
20
+
21
+ N, M, D = 1, 2, 2
22
+ Lq, L, P = 2, 2, 2
23
+ shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
24
+ level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
25
+ S = sum([(H*W).item() for H, W in shapes])
26
+
27
+
28
+ torch.manual_seed(3)
29
+
30
+
31
+ @torch.no_grad()
32
+ def check_forward_equal_with_pytorch_double():
33
+ value = torch.rand(N, S, M, D).cuda() * 0.01
34
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
35
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
36
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
37
+ im2col_step = 2
38
+ output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
39
+ output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
40
+ fwdok = torch.allclose(output_cuda, output_pytorch)
41
+ max_abs_err = (output_cuda - output_pytorch).abs().max()
42
+ max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
43
+
44
+ print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
45
+
46
+
47
+ @torch.no_grad()
48
+ def check_forward_equal_with_pytorch_float():
49
+ value = torch.rand(N, S, M, D).cuda() * 0.01
50
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
51
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
52
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
53
+ im2col_step = 2
54
+ output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
55
+ output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
56
+ fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
57
+ max_abs_err = (output_cuda - output_pytorch).abs().max()
58
+ max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
59
+
60
+ print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
61
+
62
+
63
+ def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
64
+
65
+ value = torch.rand(N, S, M, channels).cuda() * 0.01
66
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
67
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
68
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
69
+ im2col_step = 2
70
+ func = MSDeformAttnFunction.apply
71
+
72
+ value.requires_grad = grad_value
73
+ sampling_locations.requires_grad = grad_sampling_loc
74
+ attention_weights.requires_grad = grad_attn_weight
75
+
76
+ gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
77
+
78
+ print(f'* {gradok} check_gradient_numerical(D={channels})')
79
+
80
+
81
+ if __name__ == '__main__':
82
+ check_forward_equal_with_pytorch_double()
83
+ check_forward_equal_with_pytorch_float()
84
+
85
+ for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
86
+ check_gradient_numerical(channels, True, True, True)
87
+
88
+
89
+
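All three checks above depend on the attention weights being normalized so that, for every query and head, the weights over all (level, point) pairs sum to one; the tests construct attention_weights the same way each time. A stand-alone illustration of that invariant (same shapes as the tests, CPU-only):

import torch

N, Lq, M, L, P = 1, 2, 2, 2, 2
w = torch.rand(N, Lq, M, L, P) + 1e-5
w /= w.sum(-1, keepdim=True).sum(-2, keepdim=True)   # normalize jointly over levels and points
assert torch.allclose(w.flatten(-2).sum(-1), torch.ones(N, Lq, M))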
detect_tools/upn/transforms/transform.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import torch
3
+ import torchvision.transforms.functional as F
4
+
5
+
6
+ def resize(image, target, size, max_size=None):
7
+ # size can be min_size (scalar) or (w, h) tuple
8
+
9
+ def get_size_with_aspect_ratio(image_size, size, max_size=None):
10
+ w, h = image_size
11
+ if max_size is not None:
12
+ min_original_size = float(min((w, h)))
13
+ max_original_size = float(max((w, h)))
14
+ if max_original_size / min_original_size * size > max_size:
15
+ size = int(round(max_size * min_original_size / max_original_size))
16
+
17
+ if (w <= h and w == size) or (h <= w and h == size):
18
+ return (h, w)
19
+
20
+ if w < h:
21
+ ow = size
22
+ oh = int(size * h / w)
23
+ else:
24
+ oh = size
25
+ ow = int(size * w / h)
26
+
27
+ return (oh, ow)
28
+
29
+ def get_size(image_size, size, max_size=None):
30
+ if isinstance(size, (list, tuple)):
31
+ return size[::-1]
32
+ else:
33
+ return get_size_with_aspect_ratio(image_size, size, max_size)
34
+
35
+ size = get_size(image.size, size, max_size)
36
+ rescaled_image = F.resize(image, size)
37
+
38
+ if target is None:
39
+ return rescaled_image, None
40
+
41
+ ratios = tuple(
42
+ float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)
43
+ )
44
+ ratio_width, ratio_height = ratios
45
+
46
+ target = target.copy()
47
+ if "exampler_box" in target:
48
+ boxes = target["exampler_box"]
49
+ if isinstance(boxes, torch.Tensor):
50
+ scaled_boxes = boxes * torch.as_tensor(
51
+ [ratio_width, ratio_height, ratio_width, ratio_height]
52
+ )
53
+ target["exampler_box"] = scaled_boxes
54
+ elif isinstance(boxes, dict):
55
+ for k, v in boxes.items():
56
+ scaled_boxes = v * torch.as_tensor(
57
+ [ratio_width, ratio_height, ratio_width, ratio_height]
58
+ )
59
+ target["exampler_box"][k] = scaled_boxes
60
+
61
+ if "demo_pos_exampler_box" in target:
62
+ boxes = target["demo_pos_exampler_box"]
63
+ scaled_boxes = boxes * torch.as_tensor(
64
+ [ratio_width, ratio_height, ratio_width, ratio_height]
65
+ )
66
+ target["demo_pos_exampler_box"] = scaled_boxes
67
+
68
+ if "demo_neg_exampler_box" in target:
69
+ boxes = target["demo_neg_exampler_box"]
70
+ scaled_boxes = boxes * torch.as_tensor(
71
+ [ratio_width, ratio_height, ratio_width, ratio_height]
72
+ )
73
+ target["demo_neg_exampler_box"] = scaled_boxes
74
+
75
+ if "boxes" in target:
76
+ boxes = target["boxes"]
77
+ scaled_boxes = boxes * torch.as_tensor(
78
+ [ratio_width, ratio_height, ratio_width, ratio_height]
79
+ )
80
+ target["boxes"] = scaled_boxes
81
+
82
+ if "area" in target:
83
+ area = target["area"]
84
+ scaled_area = area * (ratio_width * ratio_height)
85
+ target["area"] = scaled_area
86
+
87
+ h, w = size
88
+ target["size"] = torch.tensor([h, w])
89
+
90
+ return rescaled_image, target
91
+
92
+
93
+ class RandomResize(object):
94
+
95
+ def __init__(self, sizes, max_size=None):
96
+ assert isinstance(sizes, (list, tuple))
97
+ self.sizes = sizes
98
+ self.max_size = max_size
99
+
100
+ def __call__(self, img, target=None):
101
+ size = random.choice(self.sizes)
102
+ return resize(img, target, size, self.max_size)
103
+
104
+
105
+ class ToTensor(object):
106
+
107
+ def __call__(self, img, target):
108
+ return F.to_tensor(img), target
109
+
110
+
111
+ class Normalize(object):
112
+
113
+ def __init__(self, mean, std):
114
+ self.mean = mean
115
+ self.std = std
116
+
117
+ def __call__(self, image, target=None):
118
+ image = F.normalize(image, mean=self.mean, std=self.std)
119
+ if target is None:
120
+ return image, None
121
+ target = target.copy()
122
+ h, w = image.shape[-2:]
123
+ return image, target
124
+
125
+
126
+ class Compose(object):
127
+
128
+ def __init__(self, transforms):
129
+ self.transforms = transforms
130
+
131
+ def __call__(self, image, target):
132
+ for t in self.transforms:
133
+ image, target = t(image, target)
134
+ return image, target
135
+
136
+ def __repr__(self):
137
+ format_string = self.__class__.__name__ + "("
138
+ for t in self.transforms:
139
+ format_string += "\n"
140
+ format_string += " {0}".format(t)
141
+ format_string += "\n)"
142
+ return format_string
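A minimal usage sketch of the transforms above, composed the way DETR-style pipelines usually do; the 800/1333 sizes and the ImageNet mean/std are assumed defaults here, not values read from the UPN config:

import torch
from PIL import Image

transform = Compose([
    RandomResize([800], max_size=1333),   # shortest edge to 800, longest capped at 1333
    ToTensor(),
    Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

image = Image.new("RGB", (640, 480))
target = {"boxes": torch.tensor([[10.0, 20.0, 110.0, 220.0]])}
tensor, target = transform(image, target)   # boxes are rescaled together with the image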
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.6.0
2
+ torchvision==0.21.0
3
+ transformers==4.50.1
4
+ timm==1.0.9
5
+ accelerate==1.4.0
6
+ gradio
7
+ mmengine==0.8.2
8
+ einops
9
+ flash-attn
10
+ scikit-image
run.sh ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ echo "--- install dependencies ---"
2
+ pip install -r requirements.txt
3
+
4
+ # Check whether the base dependencies were installed successfully
5
+ if [ $? -ne 0 ]; then
6
+ echo "install dependencies failed, exit."
7
+ exit 1
8
+ fi
9
+
10
+ echo "--- install dependencies successfully ---"
11
+ echo "--- compile and install local 'ops' package ---"
12
+
13
+ pip install --no-build-isolation -e ./VLM-FO1/detect_tools/upn/ops
14
+
15
+ if [ $? -ne 0 ]; then
16
+ echo "compile and install local 'ops' package failed, exit."
17
+ exit 1
18
+ fi
19
+
20
+ echo "--- compile and install local 'ops' package successfully ---"
21
+ echo "--- launch Gradio application ---"
22
+
23
+ python demo/gradio_demo.py
vlm_fo1/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
vlm_fo1/constants.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LOGDIR = "."
2
+
3
+ global DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
4
+ # Model Constants
5
+ IGNORE_INDEX = -100
6
+ IMAGE_TOKEN_INDEX = -200 #151656 #151655 #-200
7
+ DEFAULT_IMAGE_TOKEN = "<image>"
8
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
9
+ DEFAULT_IM_START_TOKEN = "<im_start>"
10
+ DEFAULT_IM_END_TOKEN = "<im_end>"
11
+
12
+ # For Qwen2_5_VL
13
+ QWEN2_5_VL_IMAGE_TOKEN = "<|image_pad|>"
14
+ QWEN2_5_VL_IMAGE_TOKEN_INDEX = 151655
15
+
16
+ # For regions
17
+ DEFAULT_REGION_TOKEN = "<region<i>>"
18
+ DEFAULT_REGION_FEATURE_TOKEN = "<regionfeat>"
19
+ DEFAULT_REGION_INDEX = -300 #151654 #151654 #-300
20
+
21
+ # For Grounding
22
+ DEFAULT_GROUNDING_START = "<ground>"
23
+ DEFAULT_GROUNDING_END = "</ground>"
24
+ DEFAULT_GROUNDING_OBJECTS_START = "<objects>"
25
+ DEFAULT_GROUNDING_OBJECTS_END = "</objects>"
26
+
27
+ # For Think
28
+ DEFAULT_THINK_START = "<think>"
29
+ DEFAULT_THINK_END = "</think>"
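The grounding constants define the markup that grounded answers are written in, and mm_utils.extract_predictions_to_bboxes parses exactly this format. A small illustration (the label and region indices are made up):

def format_grounded_answer(label, region_ids):
    # format_grounded_answer("dog", [0, 3])
    #   -> "<ground>dog</ground><objects><region0><region3></objects>"
    objects = "".join(DEFAULT_REGION_TOKEN.replace("<i>", str(i)) for i in region_ids)
    return (DEFAULT_GROUNDING_START + label + DEFAULT_GROUNDING_END
            + DEFAULT_GROUNDING_OBJECTS_START + objects + DEFAULT_GROUNDING_OBJECTS_END)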
vlm_fo1/mm_utils.py ADDED
@@ -0,0 +1,658 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from PIL import ImageDraw, ImageOps
3
+ from io import BytesIO
4
+ import base64
5
+ import re
6
+ import torch
7
+ from transformers import StoppingCriteria
8
+ from vlm_fo1.constants import IMAGE_TOKEN_INDEX, DEFAULT_REGION_INDEX
9
+ import requests
10
+ from vlm_fo1.constants import (
11
+ IMAGE_TOKEN_INDEX,
12
+ DEFAULT_IMAGE_TOKEN,
13
+ DEFAULT_IM_START_TOKEN,
14
+ DEFAULT_IM_END_TOKEN,
15
+ IGNORE_INDEX,
16
+ DEFAULT_REGION_TOKEN,
17
+ DEFAULT_REGION_FEATURE_TOKEN
18
+ )
19
+ import torch
20
+ from transformers import TextStreamer
21
+ import random
22
+ import re
23
+ from typing import List, Tuple
24
+ import io
25
+ import base64
26
+
27
+
28
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
29
+ """
30
+ Tokenizes prompts containing <image> or <image_0>... special tokens.
31
+
32
+ If the prompt uses <image_0>, <image_1>, ..., each is replaced with a placeholder index (-200).
33
+ If the prompt uses <image>, it is replaced with image_token_index.
34
+
35
+ Args:
36
+ prompt (str): The prompt potentially containing image tokens.
37
+ tokenizer: The tokenizer object.
38
+ image_token_index (int): Token id to use when encountering <image> token.
39
+ return_tensors (Optional[str]): If 'pt', return a torch tensor.
40
+
41
+ Returns:
42
+ List[int] or torch.Tensor: The tokenized input with image token indices inserted appropriately.
43
+ """
44
+ if "<image_0>" in prompt:
45
+ # Case: prompt contains indexed image tokens like <image_0>, <image_1>, etc.
46
+ image_token_pattern = re.compile(r"<image_(\d+)>")
47
+ prompt_chunks = re.split(r'<image_[0-9]+>', prompt)
48
+ image_tags = image_token_pattern.findall(prompt)
49
+
50
+ input_ids = []
51
+ for i, chunk in enumerate(prompt_chunks):
52
+ input_ids.extend(tokenizer(chunk).input_ids)
53
+ if i < len(image_tags):
54
+ # Insert placeholder where <image_n> token was.
55
+ input_ids.append(-200)
56
+ else:
57
+ # Case: prompt contains plain <image> tokens.
58
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
59
+
60
+ def insert_separator(X, sep):
61
+ # Helper function to insert a separator token between chunks.
62
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
63
+
64
+ input_ids = []
65
+ offset = 0
66
+ # If first chunk starts with <bos> token, make sure to keep it only once.
67
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
68
+ offset = 1
69
+ input_ids.append(prompt_chunks[0][0])
70
+
71
+ # Insert image_token_index between chunks.
72
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
73
+ input_ids.extend(x[offset:])
74
+ # Optionally convert output to PyTorch tensor.
75
+ if return_tensors is not None:
76
+ if return_tensors == 'pt':
77
+ return torch.tensor(input_ids, dtype=torch.long)
78
+ else:
79
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
80
+
81
+ return input_ids
82
+
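A quick illustration of where the image placeholder ends up; the stub tokenizer below exists purely for demonstration (any HuggingFace tokenizer behaves the same way through its .input_ids attribute):

class _StubTokenizer:
    bos_token_id = None
    def __call__(self, text):
        class _Out:
            pass
        out = _Out()
        out.input_ids = [ord(c) for c in text]   # one fake token id per character
        return out

ids = tokenizer_image_token("a<image>b", _StubTokenizer())
# ids == [97, -200, 98]: the text tokens with IMAGE_TOKEN_INDEX spliced in at the <image> position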
83
+ def tokenizer_image_region_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, region_token_index=DEFAULT_REGION_INDEX, return_tensors=None):
84
+ """
85
+ Tokenizes prompts containing both <image> and <regionfeat> delimiters, inserting specified token indices.
86
+
87
+ Each <image> chunk is split, and within that chunk, <regionfeat> locations receive region_token_index.
88
+
89
+ Args:
90
+ prompt (str): The prompt with <image> and <regionfeat> delimiters.
91
+ tokenizer: The tokenizer object.
92
+ image_token_index (int): Insert this at <image> splits.
93
+ region_token_index (int): Insert this at <regionfeat> splits.
94
+ return_tensors (Optional[str]): If 'pt', return torch tensor.
95
+
96
+ Returns:
97
+ List[int] or torch.Tensor: The tokenized input with region/image tokens placed.
98
+ """
99
+ # Split by <image> tags first.
100
+ image_chunks = prompt.split('<image>')
101
+
102
+ prompt_chunks = []
103
+ for chunk in image_chunks:
104
+ # Split each image chunk by <regionfeat>.
105
+ obj_chunks = chunk.split('<regionfeat>')
106
+ # Tokenize each subchunk.
107
+ token_chunks = [tokenizer(c).input_ids for c in obj_chunks]
108
+ prompt_chunks.append(token_chunks)
109
+
110
+ input_ids = []
111
+ offset = 0
112
+
113
+ # If first chunk starts with <bos> token, include only once.
114
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and len(prompt_chunks[0][0]) > 0 and prompt_chunks[0][0][0] == tokenizer.bos_token_id:
115
+ offset = 1
116
+ input_ids.append(prompt_chunks[0][0][0])
117
+
118
+ # Stitch together all chunks with region/image tokens at appropriate locations.
119
+ for i, chunk_group in enumerate(prompt_chunks):
120
+ if len(chunk_group) > 0:
121
+ input_ids.extend(chunk_group[0][offset:])
122
+ for chunk in chunk_group[1:]:
123
+ input_ids.append(region_token_index)
124
+ input_ids.extend(chunk)
125
+ # Insert <image> token except after the last image chunk.
126
+ if i < len(prompt_chunks) - 1:
127
+ input_ids.append(image_token_index)
128
+ # Optionally convert to PyTorch tensor.
129
+ if return_tensors is not None:
130
+ if return_tensors == 'pt':
131
+ return torch.tensor(input_ids, dtype=torch.long)
132
+ else:
133
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
134
+
135
+ return input_ids
136
+
137
+ class KeywordsStoppingCriteria(StoppingCriteria):
138
+ """
139
+ Implements custom stopping criteria for generation based on keywords:
140
+ If the generated output contains any of the keywords, generation stops.
141
+ """
142
+ def __init__(self, keywords, tokenizer, input_ids):
143
+ self.keywords = keywords
144
+ self.keyword_ids = []
145
+ self.max_keyword_len = 0
146
+ for keyword in keywords:
147
+ cur_keyword_ids = tokenizer(keyword).input_ids
148
+ # Remove BOS if present except for single token
149
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
150
+ cur_keyword_ids = cur_keyword_ids[1:]
151
+ if len(cur_keyword_ids) > self.max_keyword_len:
152
+ self.max_keyword_len = len(cur_keyword_ids)
153
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
154
+ self.tokenizer = tokenizer
155
+ # Track the generation start length
156
+ self.start_len = input_ids.shape[1]
157
+
158
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
159
+ """
160
+ Checks if a keyword exists in the latest generated output ids for a single batch element.
161
+ """
162
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
163
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
164
+ for keyword_id in self.keyword_ids:
165
+ truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
166
+ if torch.equal(truncated_output_ids, keyword_id):
167
+ return True
168
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
169
+ for keyword in self.keywords:
170
+ if keyword in outputs:
171
+ return True
172
+ return False
173
+
174
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
175
+ """
176
+ Checks for keywords in each batch item; stops when all have satisfied the keyword condition.
177
+ """
178
+ outputs = []
179
+ for i in range(output_ids.shape[0]):
180
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
181
+ return all(outputs)
182
+
183
+ def load_image(image_file):
184
+ """
185
+ Loads an image from a local path, base64 string, URL, or PIL.Image.
186
+
187
+ If the input image is smaller than 28x28, it will be resized to at least that size.
188
+
189
+ Args:
190
+ image_file (str or PIL.Image.Image): Image source.
191
+
192
+ Returns:
193
+ PIL.Image.Image: Loaded image in RGB mode, at least 28x28 in size.
194
+ """
195
+ if isinstance(image_file, Image.Image):
196
+ image = image_file.convert("RGB")
197
+ # Case: load from URL
198
+ elif image_file.startswith("http") or image_file.startswith("https"):
199
+ response = requests.get(image_file)
200
+ image = Image.open(BytesIO(response.content)).convert("RGB")
201
+ # Case: load from base64-encoded string
202
+ elif image_file.startswith("data:image/"):
203
+ image = re.sub(r"^data:image/[^;]+;base64,", "", image_file)  # strip any data-URI prefix, not only JPEG
204
+ image_data = base64.b64decode(image)
205
+ image = Image.open(BytesIO(image_data)).convert("RGB")
206
+ elif isinstance(image_file, str):
207
+ # Case: load from local file path
208
+ image = Image.open(image_file).convert("RGB")
209
+ else:
210
+ raise ValueError(f"Unsupported image type: {type(image_file)}")
211
+
212
+ # Ensure minimum size 28x28
213
+ if image.width < 28 or image.height < 28:
214
+ image = image.resize((max(28, image.width), max(28, image.height)))
215
+ return image
216
+
217
+ def image_to_base64(img_pil):
218
+ """
219
+ Encodes a PIL Image as JPEG in base64 format.
220
+
221
+ Args:
222
+ img_pil (PIL.Image.Image): Source image.
223
+
224
+ Returns:
225
+ str: base64-encoded JPEG image string.
226
+ """
227
+ with io.BytesIO() as buffer:
228
+ img_pil.save(buffer, format="JPEG")
229
+ base64_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
230
+ return base64_image
231
+
232
+ def draw_bboxes_and_save(
233
+ image: Image.Image,
234
+ fo1_bboxes: dict = {},
235
+ detection_bboxes: List[Tuple[int, int, int, int]] = [],
236
+ output_path: str = 'output.jpg',
237
+ color: str = 'red',
238
+ total_color: str = 'green',
239
+ width: int = 2
240
+ ) -> None:
241
+ """
242
+ Draws bounding boxes (both ground-truth/proposed and detection) on a PIL image and saves result.
243
+
244
+ Args:
245
+ image (PIL.Image.Image): Input PIL image object.
246
+ fo1_bboxes (dict): Label -> List[bbox] mapping for annotation bboxes.
247
+ detection_bboxes (List[Tuple]): List of detection bounding boxes; each bbox is (x_min, y_min, x_max, y_max).
248
+ output_path (str): Path to save the output image.
249
+ color (str): Color for fo1_bboxes.
250
+ total_color (str): Color for detection_bboxes.
251
+ width (int): Rectangle outline width.
252
+
253
+ Returns:
254
+ None
255
+ """
256
+ draw = ImageDraw.Draw(image)
257
+
258
+ # Draw detection boxes with `total_color`
259
+ for bbox in detection_bboxes:
260
+ if len(bbox) != 4:
261
+ print(f"Warning: skip the invalid bbox {bbox}")
262
+ continue
263
+ shape = [(bbox[0], bbox[1]), (bbox[2], bbox[3])]
264
+ draw.rectangle(shape, outline=total_color, width=width)
265
+
266
+ # Draw annotated bboxes with labels and `color`
267
+ for bbox_label, bbox_list in fo1_bboxes.items():
268
+ for bbox in bbox_list:
269
+ if len(bbox) != 4:
270
+ print(f"Warning: skip the invalid bbox {bbox}")
271
+ continue
272
+ shape = [(bbox[0], bbox[1]), (bbox[2], bbox[3])]
273
+ draw.rectangle(shape, outline=color, width=width)
274
+ draw.text((bbox[0], bbox[1]), bbox_label, fill=color)
275
+
276
+ # Save output image (catching common IO exceptions).
277
+ try:
278
+ image.save(output_path)
279
+ print(f"The image has been successfully saved to: {output_path}")
280
+ except IOError as e:
281
+ print(f"Error: failed to save the image to {output_path}. Reason: {e}")
282
+
283
+ def adjust_bbox(bbox_list, original_h, original_w, resize_h, resize_w):
284
+ """
285
+ Adjusts bounding boxes from original image size to resized image size, compensating for scaling.
286
+
287
+ Args:
288
+ bbox_list (List[List[float]]): List of original boxes [x1, y1, x2, y2].
289
+ original_h (int): Original image height.
290
+ original_w (int): Original image width.
291
+ resize_h (int): Resized image height.
292
+ resize_w (int): Resized image width.
293
+
294
+ Returns:
295
+ List[List[float]]: Bounding boxes transformed to resized image coordinates.
296
+ """
297
+ output_list = []
298
+ def adjust_bbox_range(bbox, width, height):
299
+ # Ensure all coordinates are within the original image border.
300
+ x1, y1, x2, y2 = bbox
301
+ x1 = max(0, min(width, x1))
302
+ y1 = max(0, min(height, y1))
303
+ x2 = max(0, min(width, x2))
304
+ y2 = max(0, min(height, y2))
305
+ return [x1, y1, x2, y2]
306
+
307
+ for bbox in bbox_list:
308
+ bbox = adjust_bbox_range(bbox, original_w, original_h)
309
+ bbox[0] = bbox[0] * resize_w / original_w
310
+ bbox[1] = bbox[1] * resize_h / original_h
311
+ bbox[2] = bbox[2] * resize_w / original_w
312
+ bbox[3] = bbox[3] * resize_h / original_h
313
+ output_list.append(bbox)
314
+ return output_list
315
+
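For example, a box covering the right half of a 200x100 image keeps its relative position after the image is resized to 100x50:

scaled = adjust_bbox([[100, 0, 200, 100]], original_h=100, original_w=200, resize_h=50, resize_w=100)
assert scaled == [[50.0, 0.0, 100.0, 50.0]]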
316
+ def extract_predictions_to_bboxes(prediction: str, bbox_list):
317
+ """
318
+ Parse prediction string in the expected format and map each ground label
319
+ to its corresponding bounding boxes using bbox_list.
320
+
321
+ Args:
322
+ prediction (str): Model output string with <ground>...<objects>... markup.
323
+ bbox_list (List[List[float]]): Full list of predicted or reference bounding boxes.
324
+
325
+ Returns:
326
+ dict: label -> list of bboxes
327
+ """
328
+ label_to_indexes = {}
329
+ label_to_bboxes = {}
330
+
331
+ match_pattern = r"<ground>(.*?)<\/ground><objects>(.*?)<\/objects>"
332
+ matches = re.findall(match_pattern, prediction)
333
+
334
+ for label_text, indexes in matches:
335
+ label_text = label_text.strip()
336
+ indexes_tags = re.findall(r"<region\d+>", indexes)
337
+ region_indexes = set([int(index.split("<region")[-1].split(">")[0]) for index in indexes_tags])
338
+ if label_text not in label_to_indexes:
339
+ label_to_indexes[label_text] = region_indexes
340
+ else:
341
+ label_to_indexes[label_text] = label_to_indexes[label_text] | region_indexes
342
+
343
+ for label, indexes in label_to_indexes.items():
344
+ label_to_bboxes[label] = [bbox_list[index] for index in indexes]
345
+
346
+ return label_to_bboxes
347
+
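For instance, with a hypothetical model output in the grounding markup and the detector's proposal boxes, the mapping comes back per label:

prediction = "<ground>dog</ground><objects><region0><region2></objects>"
proposals = [[0, 0, 10, 10], [5, 5, 20, 20], [30, 30, 60, 60]]
result = extract_predictions_to_bboxes(prediction, proposals)
# result maps "dog" to the proposal boxes at indices 0 and 2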
348
+ def extract_predictions_to_indexes(prediction: str):
349
+ """
350
+ Parse prediction string, returning label -> set-of-indexes mapping.
351
+
352
+ Args:
353
+ prediction (str): Model prediction output.
354
+
355
+ Returns:
356
+ dict: label -> set(int)
357
+ """
358
+ label_to_indexes = {}
359
+ match_pattern = r"<ground>(.*?)<\/ground><objects>(.*?)<\/objects>"
360
+ matches = re.findall(match_pattern, prediction)
361
+
362
+ for label_text, indexes in matches:
363
+ label_text = label_text.strip()
364
+ indexes_tags = re.findall(r"<region\d+>", indexes)
365
+ region_indexes = set([int(index.split("<region")[-1].split(">")[0]) for index in indexes_tags])
366
+ if label_text not in label_to_indexes:
367
+ label_to_indexes[label_text] = region_indexes
368
+ else:
369
+ label_to_indexes[label_text] = label_to_indexes[label_text] | region_indexes
370
+
371
+ return label_to_indexes
372
+
373
+ def resize_shortest_edge_images_and_bboxes(
374
+ image_list: List[Image.Image],
375
+ bbox_lists: List,
376
+ candidate_sizes: List[int] = [],
377
+ max_size: int = 2048
378
+ ):
379
+ """
380
+ Randomly selects a size for the shortest edge, and proportionally resizes both images and bounding boxes.
381
+
382
+ The function maintains the image aspect ratio and ensures that the resized dimensions do not exceed the specified max_size.
383
+ Bounding boxes are transformed accordingly.
384
+
385
+ Args:
386
+ image_list (List[Image.Image]): A list of PIL Image objects.
387
+ bbox_lists (List[List[List[float]]]): A list of lists of bounding boxes per image.
388
+ candidate_sizes (List[int]): Optional list of sizes to choose the target short edge from.
389
+ max_size (int): Maximum allowed long edge after resizing.
390
+
391
+ Returns:
392
+ Tuple[List[Image.Image], List[List[List[float]]]]:
393
+ ([resized_image1, ...], [bbox_list1, ...]); the bbox nesting mirrors the input (a single flat bbox list in yields a single flat bbox list out, see below).
394
+
395
+ Raises:
396
+ ValueError: on input list length mismatch or emptiness.
397
+ """
398
+ bbox_tensor = torch.tensor(bbox_lists)
399
+ # Normalize input: wrap bbox_lists into list-of-list, if needed.
400
+ if len(bbox_tensor.shape) == 2 and bbox_tensor.shape[1] == 4:
401
+ bbox_lists = [bbox_lists]
402
+
403
+ if not image_list or not bbox_lists:
404
+ raise ValueError("Input lists cannot be empty.")
405
+ if len(image_list) != len(bbox_lists):
406
+ raise ValueError("The lengths of the image list and the bounding box list must be the same.")
407
+
408
+ # Randomly select short edge size (if given candidate sizes)
409
+ if len(candidate_sizes) > 0:
410
+ target_size = random.choice(candidate_sizes)
411
+ else:
412
+ target_size = None
413
+
414
+ resized_images = []
415
+ transformed_bbox_lists = []
416
+
417
+ # Process each image and its corresponding bbox list
418
+ for img, bboxes in zip(image_list, bbox_lists):
419
+ original_width, original_height = img.size
420
+
421
+ # Determine scaling factor to bring short edge to target_size
422
+ shortest_side = min(original_width, original_height)
423
+ if target_size:
424
+ scale = target_size / shortest_side
425
+ else:
426
+ scale = 1.0
427
+
428
+ # Propose new height and width with this scale
429
+ new_height, new_width = int(original_height * scale), int(original_width * scale)
430
+
431
+ # If resulting long edge exceeds max_size, rescale down so that it fits.
432
+ longest_side = max(new_height, new_width)
433
+ if longest_side > max_size:
434
+ scale = max_size / longest_side
435
+ new_height, new_width = int(new_height * scale), int(new_width * scale)
436
+ # Ensure images are at least 28x28 (model may expect it)
437
+ new_width = max(28, new_width)
438
+ new_height = max(28, new_height)
439
+
440
+ # Resize image, using BICUBIC for quality if shape changes
441
+ if new_width == original_width and new_height == original_height:
442
+ resized_img = img
443
+ else:
444
+ resized_img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
445
+ resized_images.append(resized_img)
446
+
447
+ # Transform bounding boxes
448
+ current_transformed_bboxes = []
449
+ scale_ratio_x = new_width / original_width
450
+ scale_ratio_y = new_height / original_height
451
+ for bbox in bboxes:
452
+ x1, y1, x2, y2 = bbox
453
+ new_x1 = x1 * scale_ratio_x
454
+ new_y1 = y1 * scale_ratio_y
455
+ new_x2 = x2 * scale_ratio_x
456
+ new_y2 = y2 * scale_ratio_y
457
+ current_transformed_bboxes.append([new_x1, new_y1, new_x2, new_y2])
458
+ transformed_bbox_lists.append(current_transformed_bboxes)
459
+
460
+ # If original input was a single image (not list), unpack.
461
+ if len(bbox_tensor.shape) == 2 and bbox_tensor.shape[1] == 4:
462
+ return resized_images, transformed_bbox_lists[0]
463
+ else:
464
+ return resized_images, transformed_bbox_lists
465
+
466
+ def make_message_context(tokenizer, message, chat_format="chatml"):
467
+ """
468
+ Given a message dict, construct the prompt, tokenized context tokens, image URLs, and bbox_list.
469
+
470
+ Handles both standard string 'content' and multi-part (list) content, appropriately placing image/region tokens.
471
+
472
+ Args:
473
+ tokenizer: tokenizer object
474
+ message (dict): Contains role, content, and optionally bbox_list.
475
+ chat_format (str): Optionally select chat format (default 'chatml').
476
+
477
+ Returns:
478
+ tuple: (inp, context_tokens, image_urls, bbox_list)
479
+ """
480
+ image_urls = []
481
+ if chat_format == "chatml":
482
+ im_start, im_end = "<|im_start|>", "<|im_end|>"
483
+ im_start_tokens = [151644]
484
+ im_end_tokens = [151645]
485
+ nl_tokens = tokenizer.encode("\n")
486
+ role = message["role"]
487
+ content = message["content"]
488
+ bbox_list = message.get("bbox_list", None)
489
+
490
+ if role == "system":
491
+ inp = f"{im_start}{role}\n{content}{im_end}\n"
492
+ context_tokens = tokenizer.encode(
493
+ role, allowed_special=set()) + nl_tokens + tokenizer.encode(content, allowed_special=set())
494
+ context_tokens = im_start_tokens + context_tokens + im_end_tokens
495
+
496
+ if role == "user":
497
+ if isinstance(content, str):
498
+ # Plain string message
499
+ inp = f"{im_start}{role}\n{content}{im_end}\n"
500
+ context_tokens = tokenizer.encode(
501
+ role, allowed_special=set()) + nl_tokens + tokenizer.encode(content,
502
+ allowed_special=set())
503
+ context_tokens = im_start_tokens + context_tokens + im_end_tokens
504
+ if isinstance(content, list):
505
+ # Multi-part message (text and image_url parts, maybe region tokens)
506
+ inp = f"{im_start}{role}\n"
507
+ image_count = 1
508
+ for message_part in content:
509
+ if message_part["type"] == "text":
510
+ inp += f"{message_part['text']}"
511
+
512
+ if message_part["type"] == "image_url":
513
+ # Insert special vision/image tokens, possibly region tokens
514
+ inp += DEFAULT_IM_START_TOKEN + '<image>' + DEFAULT_IM_END_TOKEN + '\n'
515
+ # If regions exist, add per-region special token.
516
+ if bbox_list and len(bbox_list) > 0:
517
+ for idx, bbox in enumerate(bbox_list):
518
+ inp += DEFAULT_REGION_TOKEN.replace('<i>', str(idx)) + DEFAULT_REGION_FEATURE_TOKEN
519
+ inp += '\n'
520
+
521
+ image_urls.append(message_part['image_url']['url'])
522
+ image_count += 1
523
+ inp += f"{im_end}\n"
524
+
525
+ # Choose tokenizer logic based on whether bbox (region) list exists
526
+ if bbox_list and len(bbox_list) > 0:
527
+ context_tokens = tokenizer_image_region_token(inp, tokenizer)
528
+ else:
529
+ context_tokens = tokenizer_image_token(inp, tokenizer, image_token_index=IMAGE_TOKEN_INDEX)
530
+ return inp, context_tokens, image_urls, bbox_list
531
+
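The message schema this function (and prepare_inputs below) consumes looks like the following; the URL and box coordinates are placeholders:

message = {
    "role": "user",
    "content": [
        {"type": "image_url", "image_url": {"url": "https://example.com/street.jpg"}},
        {"type": "text", "text": "Find every pedestrian in the image."},
    ],
    # optional candidate region boxes, (x1, y1, x2, y2) in pixel coordinates
    "bbox_list": [[12, 34, 120, 260], [300, 40, 420, 250]],
}
# inp, context_tokens, image_urls, bbox_list = make_message_context(tokenizer, message)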
532
+ def prepare_inputs(model_name, model, image_processors, tokenizer, messages, device="cuda", max_tokens=512, top_p=1.0, temperature=0.0, do_sample=False):
533
+ """
534
+ Fully prepares keyword arguments for model.generate (and compatible API) from messages and model specs.
535
+
536
+ Handles prompt assembly, tokenization, image loading/preprocessing, region support, streaming, etc.
537
+ Supports specific tweak for Qwen2.5-VL style vision tokens.
538
+
539
+ Args:
540
+ model_name (str): Model identifier string.
541
+ model: Model/config object.
542
+ image_processors (tuple): (primary, auxiliary) image processors.
543
+ tokenizer: Tokenizer object.
544
+ messages (list): Multi-message input list (chat history).
545
+ device (str): Target (usually 'cuda' or 'cpu').
546
+ max_tokens, top_p, temperature, do_sample: Standard generation kwargs.
547
+
548
+ Returns:
549
+ dict: ready-to-use argument dict for model.generate().
550
+ """
551
+ # For Qwen2.5-VL, patch vision special tokens globally.
552
+ if 'qwen2.5-vl' in model_name.lower() or 'qwen2_5_vl' in model_name.lower():
553
+ global DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
554
+ DEFAULT_IM_START_TOKEN = "<|vision_start|>"
555
+ DEFAULT_IM_END_TOKEN = "<|vision_end|>"
556
+
557
+ primary_image_processor, auxiliary_image_processor = image_processors
558
+
559
+ prompt = ""
560
+ input_tokens = []
561
+ image_urls = []
562
+ # Compose prompt and accumulate all components from provided messages
563
+ for message in messages:
564
+ inp, context_tokens, image_urls, bbox_list = make_message_context(tokenizer, message)
565
+ prompt += inp
566
+ input_tokens.extend(context_tokens)
567
+
568
+ # Ensure a system prompt at start, if not already present.
569
+ if "system" not in prompt:
570
+ system_content = "system\nYou are a helpful assistant."
571
+ system_prompt = "<|im_start|>" + system_content + "<|im_end|>" + "\n"
572
+ prompt = system_prompt + prompt
573
+ system_tokens = [151644] + tokenizer(system_content).input_ids + [151645] + tokenizer("\n").input_ids
574
+ input_tokens = system_tokens + input_tokens
575
+
576
+ # Ensure prompt ends with assistant's turn.
577
+ if not prompt.endswith("<|im_start|>assistant"):
578
+ last_assistant_prompt = "<|im_start|>" + "assistant" + "\n"
579
+ prompt += last_assistant_prompt
580
+ # last_assistant_tokens = [6] + self.tokenizer("assistant\n").input_ids
581
+ last_assistant_tokens = [151644] + tokenizer("assistant\n").input_ids
582
+ input_tokens.extend(last_assistant_tokens)
583
+
584
+ primary_images_tensor = None
585
+ auxiliary_images_tensor = None
586
+ primary_image_grid_thws = None
587
+ if image_urls:
588
+ # Load images, resize them, and update bbox_list downstream
589
+ images = [load_image(i) for i in image_urls]
590
+ # print('original images[0].size:', images[0].size)
591
+ images, bbox_list = resize_shortest_edge_images_and_bboxes(images, bbox_list, max_size=2048)
592
+ # print('resized images[0].size:', images[0].size)
593
+
594
+ # When region-indexed tokens are enabled
595
+ if getattr(model.config, 'mm_use_region_index_token', False):
596
+ origin_image_size = [image.size for image in images]
597
+ aux_images = images.copy()
598
+ auxiliary_images_tensor = [auxiliary_image_processor.preprocess(i, return_tensors='pt')['pixel_values'][0].to(device) for i in aux_images]
599
+
600
+ if bbox_list and len(bbox_list) > 0:
601
+ # Limit the number of boxes (computational constraint)
602
+ bbox_list = bbox_list[:100]
603
+ resize_h, resize_w = auxiliary_images_tensor[0].shape[-2:]
604
+ original_w, original_h = origin_image_size[0]
605
+ # Adjust bbox to match resized images (post pre-processing)
606
+ bbox_list = adjust_bbox(bbox_list, original_h, original_w, resize_h, resize_w)
607
+ bbox_list = [torch.tensor(bbox_list)]
608
+ else:
609
+ bbox_list = None
610
+ else:
611
+ auxiliary_images_tensor = None
612
+
613
+ # Preprocess primary images for main vision model branch
614
+ primary_images = []
615
+ primary_image_grid_thws = []
616
+ for im in images:
617
+ processed_data = primary_image_processor.preprocess(im, videos=None, return_tensors="pt")
618
+ image_i = processed_data['pixel_values']
619
+ image_grid_thw_i = processed_data['image_grid_thw']
620
+ primary_images.append(image_i)
621
+ primary_image_grid_thws.append(image_grid_thw_i)
622
+ primary_images_tensor = [image_i.to(device) for image_i in primary_images]
623
+
624
+ # For Qwen-style, force specific end-token as stopping criterion
625
+ if "qwen" in model_name.lower():
626
+ input_ids = torch.tensor([input_tokens]).to(device)
627
+ keywords = ["<|im_end|>"]
628
+
629
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
630
+ streamer = TextStreamer(
631
+ tokenizer, skip_prompt=True, skip_special_tokens=True
632
+ )
633
+
634
+ # Default: greedy decoding if temperature=0. Else: enable sampling.
635
+ if temperature == 0.0:
636
+ do_sample = False
637
+ else:
638
+ do_sample = True
639
+
640
+ print("question:================\n", prompt, "\n=================")
641
+ # print("input ids:========", input_ids, "========")
642
+ generation_kwargs = dict(
643
+ inputs=input_ids,
644
+ images=primary_images_tensor,
645
+ images_aux=auxiliary_images_tensor,
646
+ image_grid_thws=primary_image_grid_thws,
647
+ bbox_list=bbox_list,
648
+ do_sample=do_sample,
649
+ temperature=temperature,
650
+ max_new_tokens=max_tokens,
651
+ streamer=streamer,
652
+ top_p=top_p,
653
+ use_cache=True,
654
+ stopping_criteria=[stopping_criteria],
655
+ pad_token_id=tokenizer.pad_token_id
656
+ )
657
+ return generation_kwargs
658
+
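Putting the pieces together, a typical inference call looks roughly like this sketch (the model path is a placeholder; tokenizer, model and the image-processor pair come from load_pretrained_model in vlm_fo1/model/builder.py below, and demo/gradio_demo.py is the actual entry point this repo ships):

# from vlm_fo1.model.builder import load_pretrained_model
# model_path = "/path/to/VLM-FO1-3B-qwen2_5_vl"        # placeholder checkpoint path
# tokenizer, model, image_processors = load_pretrained_model(model_path, device="cuda")
# gen_kwargs = prepare_inputs(model_path, model, image_processors, tokenizer,
#                             [message], max_tokens=512, temperature=0.0)
# with torch.inference_mode():
#     output_ids = model.generate(**gen_kwargs)
# answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)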
vlm_fo1/model/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .language_model.omchat_qwen2_5_vl import OmChatQwen25VLForCausalLM, OmChatQwen25VLConfig
vlm_fo1/model/builder.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer
2
+ import torch
3
+ from vlm_fo1.model import *
4
+ from safetensors.torch import load_file
5
+ import os
6
+
7
+
8
+ def load_pretrained_model(model_path, load_8bit=False, load_4bit=False, device="cuda"):
9
+ """
10
+ Loads a pretrained model along with its vision towers (and associated image processors).
11
+ This function supports loading in 8bit/4bit precision and explicit device placement.
12
+
13
+ Args:
14
+ model_path (str): Path to the pretrained model directory.
15
+ load_8bit (bool): Whether to load the model in 8bit mode.
16
+ load_4bit (bool): Whether to load the model in 4bit mode.
17
+ device (str): Device to load model onto, e.g., "cuda" or "cpu".
18
+
19
+ Returns:
20
+ tuple: (tokenizer, model, image_processor)
21
+ """
22
+ kwargs = {"device_map": device}
23
+
24
+ # Set model loading parameters for quantization or floating point
25
+ if load_8bit:
26
+ kwargs['load_in_8bit'] = True
27
+ elif load_4bit:
28
+ kwargs['load_in_4bit'] = True
29
+ else:
30
+ kwargs['torch_dtype'] = torch.bfloat16
31
+
32
+ # print(model_path)
33
+
34
+ # Only proceed for vlm-fo1 models
35
+ if 'vlm-fo1' in model_path.lower():
36
+ # Load tokenizer (slow tokenizer enforced)
37
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
38
+ # If this is the Qwen2.5-VL variant, load with additional kwargs
39
+ if 'qwen2.5-vl' in model_path.lower() or 'qwen2_5_vl' in model_path.lower():
40
+ model, loading_info = OmChatQwen25VLForCausalLM.from_pretrained(
41
+ model_path,
42
+ low_cpu_mem_usage=True,
43
+ output_loading_info=True,
44
+ attn_implementation="flash_attention_2",
45
+ **kwargs,
46
+ cache_dir='./resources',
47
+ )
48
+ # print(f'OmChatQwen25VLForCausalLM loading_info: {loading_info}')
49
+ # (For other variants of vlm-fo1, model loading detail may need additional condition.)
50
+
51
+ if 'vlm-fo1' in model_path.lower():
52
+ # --- Vision Tower Loading ---
53
+ # Load the main vision tower weights from model_path if it is not yet loaded
54
+ primary_vision_tower = model.get_vision_tower()
55
+ if primary_vision_tower and not primary_vision_tower.is_loaded:
56
+ primary_vision_tower.load_model(model_path=model_path, is_train=False)
57
+ primary_vision_tower.to(device=device, dtype=torch.bfloat16) # Move to correct device/dtype
58
+
59
+ # Grab primary image processor from vision tower, if present
60
+ if primary_vision_tower:
61
+ primary_image_processor = primary_vision_tower.image_processor
62
+
63
+ # --- Auxiliary Vision Tower Handling (Qwen2.5-VL case only) ---
64
+ if 'qwen2.5-vl' in model_path.lower() or 'qwen2_5_vl' in model_path.lower():
65
+ try:
66
+ aux_image_size = model.config.aux_image_size
67
+ except Exception:
68
+ # If aux_image_size is missing from the config, fall back to 768
69
+ aux_image_size = 768
70
+
71
+ aux_image_aspect_ratio = model.config.aux_image_aspect_ratio
72
+ aux_vision_tower = model.get_vision_tower_aux()
73
+ # Only load if not already loaded
74
+ if aux_vision_tower and not aux_vision_tower.is_loaded:
75
+ aux_vision_tower.load_model(image_size=aux_image_size, is_train=False, aspect_ratio=aux_image_aspect_ratio)
76
+ aux_vision_tower.to(device=device, dtype=torch.bfloat16)
77
+
78
+ # Get auxiliary image processor if there is an aux vision tower
79
+ if aux_vision_tower:
80
+ aux_image_processor = aux_vision_tower.image_processor
81
+ else:
82
+ aux_image_processor = None # No auxiliary vision tower; keep the (primary, aux) tuple below consistent
83
+
84
+ # image_processor returned as a tuple of (primary, aux)
85
+ image_processor = (primary_image_processor, aux_image_processor)
86
+
87
+ # --- Ensure vision_tower and vision_tower_aux are loaded with weights from model_path ---
88
+ # if 'vlm-fo1' in model_path.lower():
89
+ # print(f"Loading weights from {model_path} to ensure vision_tower uses the correct weights...") # Inform user we are loading vision weights
90
+
91
+ # # --- Gather all safetensors files in the model path (for sharded checkpoints) ---
92
+ # state_dict = {}
93
+ # safetensor_files = [f for f in os.listdir(model_path) if f.endswith('.safetensors')]
94
+
95
+ # if safetensor_files:
96
+ # for safetensor_file in safetensor_files:
97
+ # file_path = os.path.join(model_path, safetensor_file)
98
+ # shard_state_dict = load_file(file_path, device="cpu")
99
+ # state_dict.update(shard_state_dict)
100
+ # else:
101
+ # # Fallback to legacy .bin checkpoint if no safetensors found
102
+ # state_dict = torch.load(f"{model_path}/pytorch_model.bin", map_location="cpu")
103
+
104
+ # # --- Filter out only vision_tower and vision_tower_aux related weights ---
105
+ # vision_tower_keys = [k for k in state_dict.keys() if "vision_tower." in k]
106
+ # vision_tower_state_dict = {k: state_dict[k] for k in vision_tower_keys if k in state_dict}
107
+
108
+ # if vision_tower_keys:
109
+ # # print(f"Found {len(vision_tower_keys)} vision_tower weights")
110
+ # # Load weights into main vision tower
111
+ # if primary_vision_tower and primary_vision_tower.is_loaded:
112
+ # # Strips the prefix "model.vision_tower." before loading (for compatibility with submodules)
113
+ # missing_keys, unexpected_keys = primary_vision_tower.load_state_dict(
114
+ # {k.replace("model.vision_tower.", ""): v for k, v in vision_tower_state_dict.items()
115
+ # if k.startswith("model.vision_tower.")},
116
+ # strict=True
117
+ # )
118
+ # print(f"vision_tower weights loaded, missing keys: {missing_keys}, unexpected keys: {unexpected_keys}")
119
+
120
+ # # If there is an aux vision tower (Qwen2.5-VL) load its weights as well
121
+ # if 'qwen2.5-vl' in model_path.lower() or 'qwen2_5_vl' in model_path.lower():
122
+ # if aux_vision_tower and aux_vision_tower.is_loaded:
123
+ # vision_tower_aux_keys = [k for k in state_dict.keys() if "vision_tower_aux." in k]
124
+ # if vision_tower_aux_keys:
125
+ # # print(f"Found {len(vision_tower_aux_keys)} vision_tower_aux weights")
126
+ # vision_tower_aux_state_dict = {k: state_dict[k] for k in vision_tower_aux_keys if k in state_dict}
127
+ # # Strip "model.vision_tower_aux." prefix before loading for compatibility
128
+ # missing_keys, unexpected_keys = aux_vision_tower.load_state_dict(
129
+ # {k.replace("model.vision_tower_aux.", ""): v for k, v in vision_tower_aux_state_dict.items()
130
+ # if k.startswith("model.vision_tower_aux.")},
131
+ # strict=True
132
+ # )
133
+ # print(f"vision_tower_aux weights loaded, missing keys: {missing_keys}, unexpected keys: {unexpected_keys}")
134
+
135
+ # else:
136
+ # # If no vision tower weights found, raise an error
137
+ # print("No vision_tower weights found")
138
+ # raise Exception("No vision_tower weights found")
139
+
140
+ # Set model to eval mode and move to correct device before returning
141
+ model.eval()
142
+ model.to(device=device, dtype=torch.bfloat16)
143
+ return tokenizer, model, image_processor
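Usage is a single call; note that the function keys its behavior off substrings of the path, so the checkpoint directory name must contain 'vlm-fo1' (and 'qwen2.5-vl' or 'qwen2_5_vl' for this variant). The path below is a placeholder:

# tokenizer, model, image_processor = load_pretrained_model(
#     "/path/to/VLM-FO1-3B-qwen2_5_vl",
#     load_4bit=False,
#     device="cuda",
# )
# primary_image_processor, aux_image_processor = image_processor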
vlm_fo1/model/language_model/omchat_qwen2_5_vl.py ADDED
@@ -0,0 +1,576 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from transformers import Qwen2_5_VLConfig, AutoConfig, AutoModelForCausalLM
7
+ from vlm_fo1.model.multimodal_encoder.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel, Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLCausalLMOutputWithPast
8
+ from vlm_fo1.model.multimodal_encoder.qwen2_5_vl_encoder import Qwen2_5_VlVisionTower
9
+ from vlm_fo1.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_REGION_INDEX, QWEN2_5_VL_IMAGE_TOKEN, QWEN2_5_VL_IMAGE_TOKEN_INDEX
10
+
11
+ from ..omchat_arch import OmChatMetaModel, OmChatMetaForCausalLM
12
+
13
+ # Custom config which extends Qwen2_5_VLConfig for OmChat multimodal model
14
+ class OmChatQwen25VLConfig(Qwen2_5_VLConfig):
15
+ model_type = "omchat_qwen2_5_vl"
16
+ rotary_type = "normal_rotary"
17
+ multi_scale_im = None
18
+ vision_tower_aux = None
19
+
20
+ # Core model definition: inherits from OmChat and Qwen multimodal base
21
+ class OmChatQwen25VLModel(OmChatMetaModel, Qwen2_5_VLModel):
22
+ config_class = OmChatQwen25VLConfig
23
+
24
+ def __init__(self, config: Qwen2_5_VLConfig):
25
+ super(OmChatQwen25VLModel, self).__init__(config)
26
+
27
+ # Main class for multimodal CausalLM
28
+ class OmChatQwen25VLForCausalLM(Qwen2_5_VLForConditionalGeneration, OmChatMetaForCausalLM):
29
+ config_class = OmChatQwen25VLConfig
30
+
31
+ def __init__(self, config, delay_load=True):
32
+ # Ensure config has delay_load property
33
+ if not hasattr(config, 'delay_load'):
34
+ config.delay_load = delay_load
35
+ super(Qwen2_5_VLForConditionalGeneration, self).__init__(config)
36
+ self.model = OmChatQwen25VLModel(config)
37
+ self.vocab_size = config.vocab_size
38
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
39
+ self.rope_deltas = None # cache rope_deltas here
40
+
41
+ self.post_init()
42
+
43
+ # Encode input images into feature representations
44
+ def encode_images(self, images, images_grid_thw=None):
45
+ # If vision_tower is Qwen2.5-specific, use its custom forward signature
46
+ if isinstance(self.get_model().get_vision_tower(), Qwen2_5_VlVisionTower):
47
+ image_features = self.get_model().get_vision_tower()(images, images_grid_thw)
48
+ image_features, image_grid_thws, multi_level_features = image_features
49
+ # If multiple images, handle concatenation
50
+ if type(image_features) is list:
51
+ # List has items of shape (1, seq_len, dim)
52
+ token_length_list = [i.shape[1] for i in image_features]
53
+ image_features = torch.cat(image_features, dim=1) # Concatenate to (1, total_seq_len, dim)
54
+ else:
55
+ image_features = self.get_model().get_vision_tower()(images)
56
+ image_grid_thws = None
57
+ multi_level_features = None
58
+
59
+ image_features = self.get_model().mm_projector(image_features)
60
+
61
+ # Split concatenated image features back by original lengths (for multi-image case)
62
+ if isinstance(self.get_model().get_vision_tower(), Qwen2_5_VlVisionTower):
63
+ start = 0
64
+ new_image_features = []
65
+ # Split according to token_length_list
66
+ for length in token_length_list:
67
+ end = start + length
68
+ new_image_features.append(image_features[:, start:end, :].squeeze(0))
69
+ start = end
70
+ image_features = new_image_features
71
+
72
+ return image_features, image_grid_thws, multi_level_features
73
+
74
+ # Encode region regions (bounding boxes) into features, optionally using auxiliary vision tower
75
+ def encode_regions(self, images, bbox_list, vt_multi_level_features=None, vt_images_size=None):
76
+ aux_image_features_list = self.get_model().get_vision_tower_aux()(images)
77
+ region_features = []
78
+ if getattr(self.config, "mm_use_vision_tower_region_feature", False):
79
+ image_features_list = vt_multi_level_features
80
+ for batch_idx, (image_features, aux_image_features) in enumerate(zip(image_features_list, aux_image_features_list)):
81
+
82
+ if getattr(self.config, "mm_use_simpleFPN_for_vt", False):
83
+ multilevel_visual_feats = image_features[-1]
84
+ else:
85
+ multilevel_visual_feats = image_features
86
+ multilevel_aux_visual_feats = aux_image_features["image_features"]
87
+ boxes = bbox_list[batch_idx]
88
+
89
+ # If no boxes provided, use dummy box (covers tiny region)
90
+ if boxes is None or len(boxes) == 0:
91
+ boxes = torch.tensor([[0, 10, 0, 10]], device=multilevel_aux_visual_feats[0].device, dtype=torch.float32)
92
+
93
+ boxes = boxes.to(torch.float32).to(multilevel_aux_visual_feats[0].device)
94
+ current_image_height, current_image_width = images[batch_idx].shape[-2:]
95
+ original_height, original_width = vt_images_size[batch_idx]
96
+ # Scale bounding boxes from original image size to processed size
97
+ scale_height = original_height / current_image_height
98
+ scale_width = original_width / current_image_width
99
+ vt_boxes = boxes * torch.tensor([scale_width, scale_height, scale_width, scale_height], device=boxes.device)
100
+
101
+ extracted_region_feat = self.get_model().object_vp_extractor(
102
+ aux_multi_level_features=multilevel_aux_visual_feats,
103
+ vt_multi_level_features=multilevel_visual_feats,
104
+ aux_boxes=[boxes],
105
+ vt_boxes=[vt_boxes]
106
+ ).squeeze(0).to(multilevel_aux_visual_feats[0].dtype)
107
+ region_feat = self.get_model().mm_projector_aux(extracted_region_feat) # [num_bbox, 2048]
108
+ region_features.append(region_feat)
109
+ else:
110
+ # Extract region features only from auxiliary vision tower
111
+ for batch_idx, image_features in enumerate(aux_image_features_list):
112
+ multilevel_visual_feats = image_features["image_features"]
113
+ last_feat = image_features["last_feat"]
114
+ boxes = bbox_list[batch_idx]
115
+
116
+ if boxes is None or len(boxes) == 0:
117
+ boxes = torch.tensor([[0, 10, 0, 10]], device=multilevel_visual_feats[0].device, dtype=torch.float32)
118
+
119
+ multi_level_aux_features = multilevel_visual_feats
120
+ boxes = boxes.to(torch.float32).to(multi_level_aux_features[0].device)
121
+ extracted_region_feat = self.get_model().object_vp_extractor(
122
+ multi_level_aux_features,
123
+ [boxes],
124
+ ).squeeze(0).to(multi_level_aux_features[0].dtype)
125
+ region_feat = self.get_model().mm_projector_aux(extracted_region_feat) # [num_bbox, 2880]
126
+ region_features.append(region_feat)
127
+
128
+ return region_features
129
+
130
+ def get_model(self):
131
+ # Getter for model. Used to access backbone/model internals.
132
+ return self.model
133
+
134
+ # Convert sequence of input_ids/labels/images/boxes to multimodal embedding and associated masks/ids for transformer input.
135
+ def prepare_inputs_labels_for_qwen2_5_vl_multimodal(
136
+ self, input_ids, position_ids, attention_mask, past_key_values, labels, images, images_aux=None, bbox_list=None, image_grid_thws=None
137
+ ):
138
+ # ========================== Above this line, input parsing and batching =============================
139
+ vision_tower = self.get_vision_tower()
140
+ video_tower = self.get_video_tower()
141
+ vision_tower_aux = self.get_vision_tower_aux()
142
+ # Fast-path for non-multimodal case or first step in generation (i.e. only one token in input)
143
+ if (vision_tower is None and video_tower is None) or images is None or input_ids.shape[1] == 1:
144
+ if past_key_values is not None and (vision_tower is not None or video_tower is not None) and images is not None and input_ids.shape[1] == 1:
145
+
146
+ target_shape = past_key_values[-1][-1].shape[-2] + 1
147
+ attention_mask = torch.cat((attention_mask, torch.ones(
148
+ (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
149
+ dtype=attention_mask.dtype,
150
+ device=attention_mask.device
151
+ )), dim=1)
152
+
153
+ position_ids=None
154
+ cache_position = torch.tensor([target_shape - 1],device=attention_mask.device)
155
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels, None, cache_position
156
+
157
+ # Indices for images (3D or 2D tensors) and videos (4D tensors)
158
+ image_idx = [idx for idx, img in enumerate(images) if img.ndim == 3 or img.ndim == 2]
159
+ is_all_image = len(image_idx) == len(images)
160
+ video_idx = [idx for idx, vid in enumerate(images) if vid.ndim == 4]
161
+
162
+ # Stack image and video tensors accordingly for mini-batch processing
163
+ if isinstance(vision_tower, Qwen2_5_VlVisionTower):
164
+ images_minibatch = [images[idx] for idx in image_idx] if len(image_idx) > 0 else [] # list of [c,h,w], can have variable shapes
165
+ else:
166
+ images_minibatch = torch.stack([images[idx] for idx in image_idx]) if len(image_idx) > 0 else [] # tensor [mini_b, c, h, w]
167
+ videos_minibatch = torch.stack([images[idx] for idx in video_idx]) if len(video_idx) > 0 else [] # tensor [mini_b, c, t, h, w]
168
+
169
+ # Auxiliary batch for region encoding, if relevant
170
+ if vision_tower_aux is not None and images_aux is not None:
171
+ images_minibatch_aux = [images_aux[idx].unsqueeze(0) for idx in image_idx] if len(image_idx) > 0 else [] # list of [1, c, h, w]
172
+
173
+ # tmp_image_features will be indexed to scatter extracted image/video features into original batch positions
174
+ tmp_image_features = [None] * (len(image_idx) + len(video_idx))
175
+ if getattr(images_minibatch, 'ndim', 0) == 4 or (type(images_minibatch) is list and len(images_minibatch) > 0): # batch consists of images, [mini_b, c, h, w]
176
+ if vision_tower is not None:
177
+ image_features_minibatch, image_grid_thws_minibatch, vt_multi_level_features_minibatch = self.encode_images(images_minibatch, image_grid_thws) # [mini_b, l, c]
178
+ else:
179
+ image_features_minibatch = torch.randn(1).to(self.device) # dummy feature when no image vision tower is loaded (e.g. video-only tuning)
180
+
181
+ # Map extracted image features back to their places in the original batch
182
+ for i, pos in enumerate(image_idx):
183
+ tmp_image_features[pos] = image_features_minibatch[i]
184
+
185
+ # Handle auxiliary region features if enabled and boxes provided
186
+ if vision_tower_aux is not None and bbox_list is not None and len(bbox_list) > 0:
187
+ if isinstance(self.get_model().get_vision_tower(), Qwen2_5_VlVisionTower):
188
+ patch_size = self.get_model().get_vision_tower().config.patch_size
189
+ vt_images_size_minibatch = [im_grid_thw[0][-2:]*patch_size for im_grid_thw in image_grid_thws]
190
+ region_features = self.encode_regions(images_minibatch_aux, bbox_list, vt_multi_level_features_minibatch, vt_images_size_minibatch) # [mini_b, l, c]
191
+ else:
192
+ region_features = None
193
+
194
+ # Same as above, but for video features if any
195
+ if getattr(videos_minibatch, 'ndim', 0) == 5: # batch consists of videos, [mini_b, c, t, h, w]
196
+ video_features_minibatch = self.encode_videos(videos_minibatch) # fake list [mini_b, t, l, c]
197
+ for i, pos in enumerate(video_idx):
198
+ tmp_image_features[pos] = video_features_minibatch[i]
199
+
200
+ # Flatten image feature slot list to proper order for current batch
201
+ new_tmp = []
202
+ for image in tmp_image_features:
203
+ # If multi-image per item, flatten out
204
+ if isinstance(image, list):
205
+ t = len(image)
206
+ for i in range(t):
207
+ new_tmp.append(image[i])
208
+ else:
209
+ new_tmp.append(image)
210
+ image_features = new_tmp
211
+
212
+ # =========================== Now, build multimodal input & target sequences =========================
213
+
214
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
215
+ raise NotImplementedError
216
+
217
+ _labels = labels
218
+ _position_ids = position_ids
219
+ _attention_mask = attention_mask
220
+
221
+ # Default construction of masks etc.
222
+ if attention_mask is None:
223
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
224
+ else:
225
+ attention_mask = attention_mask.bool()
226
+ if position_ids is None:
227
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
228
+ if labels is None:
229
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
230
+
231
+ # For each batch item, strip padded tokens based on attention_mask
232
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
233
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
234
+
235
+ # If neither region auxiliary nor bboxes present: process classic image-text input
236
+ if vision_tower_aux is None and (bbox_list is None or all(x is None for x in bbox_list)):
237
+ new_input_embeds = []
238
+ new_labels = []
239
+ new_input_ids = []
240
+ cur_image_idx = 0
241
+ image_nums_in_batch = []
242
+
243
+ for batch_idx, cur_input_ids in enumerate(input_ids):
244
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
245
+ image_nums_in_batch.append(num_images)
246
+ # If there are no image markers, just get text features
247
+ if num_images == 0:
248
+ cur_image_features = image_features[cur_image_idx]
249
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
250
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
251
+ new_input_embeds.append(cur_input_embeds)
252
+ new_labels.append(labels[batch_idx])
253
+ new_input_ids.append(cur_input_ids)
254
+ cur_image_idx += 1
255
+ continue
256
+
257
+ # Split on image token indices: replace them with image features after conversion
258
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
259
+ cur_input_ids_noim = []
260
+ cur_labels = labels[batch_idx]
261
+ cur_labels_noim = []
262
+ for i in range(len(image_token_indices) - 1):
263
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
264
+ cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
265
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
266
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
267
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
268
+
269
+ cur_new_input_embeds = []
270
+ cur_new_labels = []
271
+ cur_new_input_ids = []
272
+ for i in range(num_images + 1):
273
+ # Interleave text and image features
274
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
275
+ cur_new_labels.append(cur_labels_noim[i])
276
+ cur_new_input_ids.append(cur_input_ids_noim[i])
277
+ if i < num_images:
278
+ cur_image_features = image_features[cur_image_idx].to(self.device)
279
+ cur_image_idx += 1
280
+ cur_new_input_embeds.append(cur_image_features)
281
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
282
+ cur_new_input_ids.append(torch.full((cur_image_features.shape[0],), self.config.image_token_id, device=cur_labels.device, dtype=cur_labels.dtype))
283
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
284
+ cur_new_labels = torch.cat(cur_new_labels)
285
+ cur_new_input_ids = torch.cat(cur_new_input_ids)
286
+
287
+ new_input_embeds.append(cur_new_input_embeds)
288
+ new_labels.append(cur_new_labels)
289
+ new_input_ids.append(cur_new_input_ids)
290
+ # If region markers or region features enabled in config
291
+ else:
292
+ new_input_embeds = []
293
+ new_labels = []
294
+ new_input_ids = []
295
+ cur_image_idx = 0
296
+ image_nums_in_batch = []
297
+
298
+ for batch_idx, cur_input_ids in enumerate(input_ids):
299
+ cur_region_idx = 0
300
+ # Detect image and region special token counts
301
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
302
+ num_regions = (cur_input_ids == DEFAULT_REGION_INDEX).sum() if DEFAULT_REGION_INDEX in cur_input_ids else 0
303
+ image_nums_in_batch.append(num_images)
304
+
305
+ # If no markers, just do text embedding for this item
306
+ if num_images == 0 and num_regions == 0:
307
+ cur_image_features = image_features[cur_image_idx]
308
+ cur_region_features = region_features[cur_region_idx]
309
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
310
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_region_features[0:0]], dim=0)
311
+ new_input_embeds.append(cur_input_embeds)
312
+ new_labels.append(labels[batch_idx])
313
+ new_input_ids.append(cur_input_ids)
314
+ cur_image_idx += 1
315
+ continue
316
+
317
+ # Get all special marker indices (image/region)
318
+ image_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
319
+ region_indices = torch.where(cur_input_ids == DEFAULT_REGION_INDEX)[0].tolist() if num_regions > 0 else []
320
+ all_special_indices = sorted([-1] + image_indices + region_indices + [cur_input_ids.shape[0]])
321
+
322
+ # Split out plain text chunks between special markers
323
+ cur_input_ids_segments = []
324
+ cur_labels = labels[batch_idx]
325
+ cur_labels_segments = []
326
+
327
+ for i in range(len(all_special_indices) - 1):
328
+ cur_input_ids_segments.append(cur_input_ids[all_special_indices[i]+1:all_special_indices[i+1]])
329
+ cur_labels_segments.append(cur_labels[all_special_indices[i]+1:all_special_indices[i+1]])
330
+
331
+ # Project text ids to word embeddings
332
+ split_sizes = [x.shape[0] for x in cur_labels_segments]
333
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_segments))
334
+ if num_regions == 0 and vision_tower_aux is not None and region_features is not None:
335
+ cur_region_features = region_features[cur_region_idx]
336
+ temp_input_embeds = torch.cat([cur_input_embeds, cur_region_features[0:0]], dim=0)
337
+ cur_input_embeds = temp_input_embeds
338
+
339
+ cur_input_embeds_segments = torch.split(cur_input_embeds, split_sizes, dim=0)
340
+
341
+ # Reassemble text and image/region segments in order
342
+ cur_new_input_embeds = []
343
+ cur_new_labels = []
344
+ cur_new_input_ids = []
345
+
346
+ for i in range(len(all_special_indices) - 1):
347
+ # Insert current text segment
348
+ cur_new_input_embeds.append(cur_input_embeds_segments[i])
349
+ cur_new_labels.append(cur_labels_segments[i])
350
+ cur_new_input_ids.append(cur_input_ids_segments[i])
351
+ # If next is image, insert feature representation
352
+ if all_special_indices[i+1] in image_indices:
353
+ cur_image_features = image_features[cur_image_idx].to(self.device)
354
+ cur_image_idx += 1
355
+ cur_new_input_embeds.append(cur_image_features)
356
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
357
+ cur_new_input_ids.append(torch.full((cur_image_features.shape[0],), self.config.image_token_id, device=cur_labels.device, dtype=cur_labels.dtype))
358
+
359
+ # If next is region token, insert extracted region features
360
+ elif all_special_indices[i+1] in region_indices:
361
+ cur_region_features = region_features[batch_idx][cur_region_idx].to(self.device).unsqueeze(0)
362
+ cur_region_idx += 1
363
+ cur_new_input_embeds.append(cur_region_features)
364
+
365
+ cur_new_labels.append(torch.full((cur_region_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
366
+ cur_new_input_ids.append(torch.full((cur_region_features.shape[0],), DEFAULT_REGION_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
367
+ # Combine for this batch item
368
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
369
+ cur_new_labels = torch.cat(cur_new_labels)
370
+ cur_new_input_ids = torch.cat(cur_new_input_ids)
371
+ new_input_embeds.append(cur_new_input_embeds)
372
+ new_labels.append(cur_new_labels)
373
+ new_input_ids.append(cur_new_input_ids)
374
+ # Truncate sequences to maximum model length, if image+region tokens caused overflow
375
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
376
+ if tokenizer_model_max_length is not None:
377
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
378
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
379
+
380
+ # Pad sequences in the batch to same length; compute batch masks
381
+ max_len = max(x.shape[0] for x in new_input_embeds)
382
+ batch_size = len(new_input_embeds)
383
+
384
+ new_input_embeds_padded = []
385
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
386
+ new_input_ids_padded = torch.full((batch_size, max_len), self.config.bos_token_id, dtype=new_input_ids[0].dtype, device=new_input_ids[0].device)
387
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
388
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
389
+
390
+ # Left or right padding as per config; fill padded tensors
391
+ for i, (cur_new_embed, cur_new_labels, cur_new_input_ids) in enumerate(zip(new_input_embeds, new_labels, new_input_ids)):
392
+ cur_len = cur_new_embed.shape[0]
393
+ if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
394
+ # Left pad: add zeros before text tokens/features
395
+ new_input_embeds_padded.append(torch.cat((
396
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
397
+ cur_new_embed
398
+ ), dim=0))
399
+ if cur_len > 0:
400
+ new_labels_padded[i, -cur_len:] = cur_new_labels
+ new_input_ids_padded[i, -cur_len:] = cur_new_input_ids
401
+ attention_mask[i, -cur_len:] = True
402
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
403
+ else:
404
+ # Right pad: add zeros after text tokens/features
405
+ new_input_embeds_padded.append(torch.cat((
406
+ cur_new_embed,
407
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
408
+ ), dim=0))
409
+ if cur_len > 0:
410
+ new_labels_padded[i, :cur_len] = cur_new_labels
411
+ new_input_ids_padded[i, :cur_len] = cur_new_input_ids
412
+ attention_mask[i, :cur_len] = True
413
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
414
+
415
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
416
+ new_input_ids = new_input_ids_padded
417
+
418
+ # Only set new_labels if original labels were not None
419
+ if _labels is None:
420
+ new_labels = None
421
+ else:
422
+ new_labels = new_labels_padded
423
+
424
+ # Similarly handle provided attention_mask/position_ids overrides
425
+ if _attention_mask is None:
426
+ attention_mask = None
427
+ else:
428
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
429
+
430
+ if _position_ids is None:
431
+ position_ids = None
432
+
433
+ # For Qwen2.5 vision towers, use and concatenate image_grid_thws for positional computations
434
+ if isinstance(self.get_model().get_vision_tower(), Qwen2_5_VlVisionTower):
435
+ image_grid_thws = []
436
+ cur_image_idx = 0
437
+ for num_images in image_nums_in_batch:
438
+ if num_images == 0:
439
+ cur_image_idx += 1
440
+ continue
441
+ image_grid_thws += image_grid_thws_minibatch[cur_image_idx:cur_image_idx+num_images]
442
+ cur_image_idx += num_images
443
+
444
+ if len(image_grid_thws) > 0:
445
+ image_grid_thws = torch.cat(image_grid_thws, dim=0)
446
+ else:
447
+ image_grid_thws = None
448
+
449
+ rope_index_kwargs = {
450
+ "input_ids": new_input_ids,
451
+ "image_grid_thw": image_grid_thws,
452
+ "video_grid_thw": None,
453
+ "attention_mask": attention_mask,
454
+ }
455
+
456
+ # Compute new position_ids and rope_deltas for transformer (for rotary embeddings)
457
+ position_ids, rope_deltas = self.get_rope_index(**rope_index_kwargs)
458
+ cache_position = torch.arange(new_input_embeds.shape[1], device=new_input_embeds.device)
459
+ else:
460
+ rope_deltas = None
461
+ cache_position = None
462
+ # Final output is a tuple mimicking HuggingFace prepare_inputs_for_generation return
463
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels, rope_deltas, cache_position
464
+
465
+ # Patch forward() of HF CausalLM to allow multimodal embedding with images/regions
466
+ def forward(
467
+ self,
468
+ input_ids: torch.LongTensor = None,
469
+ attention_mask: Optional[torch.Tensor] = None,
470
+ position_ids: Optional[torch.LongTensor] = None,
471
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
472
+ inputs_embeds: Optional[torch.FloatTensor] = None,
473
+ labels: Optional[torch.LongTensor] = None,
474
+ use_cache: Optional[bool] = None,
475
+ output_attentions: Optional[bool] = None,
476
+ output_hidden_states: Optional[bool] = None,
477
+ return_dict: Optional[bool] = None,
478
+ pixel_values: Optional[torch.Tensor] = None,
479
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
480
+ image_grid_thw: Optional[torch.LongTensor] = None,
481
+ video_grid_thw: Optional[torch.LongTensor] = None,
482
+ rope_deltas: Optional[torch.LongTensor] = None,
483
+ cache_position: Optional[torch.LongTensor] = None,
484
+ second_per_grid_ts: Optional[torch.Tensor] = None,
485
+ images: Optional[torch.FloatTensor] = None,
486
+ images_aux: Optional[torch.FloatTensor] = None,
487
+ bbox_list: Optional[torch.FloatTensor] = None,
488
+ image_grid_thws: Optional[torch.FloatTensor] = None,
489
+ ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
490
+
491
+ if inputs_embeds is None:
492
+ (
493
+ input_ids,
494
+ position_ids,
495
+ attention_mask,
496
+ past_key_values,
497
+ inputs_embeds,
498
+ labels,
499
+ rope_deltas,
500
+ cache_position
501
+ ) = self.prepare_inputs_labels_for_qwen2_5_vl_multimodal(
502
+ input_ids,
503
+ position_ids,
504
+ attention_mask,
505
+ past_key_values,
506
+ labels,
507
+ images,
508
+ images_aux,
509
+ bbox_list,
510
+ image_grid_thws
511
+ )
512
+
513
+ if rope_deltas is not None:
514
+ self.rope_deltas = rope_deltas
515
+
516
+ # Call base CausalLM forward, with possibly replaced multimodal embeddings
517
+ out = super().forward(
518
+ input_ids=input_ids,
519
+ attention_mask=attention_mask,
520
+ position_ids=position_ids,
521
+ past_key_values=past_key_values,
522
+ inputs_embeds=inputs_embeds,
523
+ labels=labels,
524
+ use_cache=use_cache,
525
+ output_attentions=output_attentions,
526
+ output_hidden_states=output_hidden_states,
527
+ rope_deltas=rope_deltas,
528
+ cache_position=cache_position,
529
+ second_per_grid_ts=second_per_grid_ts,
530
+ return_dict=return_dict
531
+ )
532
+ return out
533
+
534
+ # Prepare model input dict for autoregressive generation (for use with generation methods like generate())
535
+ def prepare_inputs_for_generation(
536
+ self,
537
+ input_ids,
538
+ past_key_values=None,
539
+ attention_mask=None,
540
+ inputs_embeds=None,
541
+ cache_position=None,
542
+ position_ids=None,
543
+ use_cache=True,
544
+ pixel_values=None,
545
+ pixel_values_videos=None,
546
+ image_grid_thw=None,
547
+ video_grid_thw=None,
548
+ second_per_grid_ts=None,
549
+ images: Optional[torch.FloatTensor] = None,
550
+ images_aux: Optional[torch.FloatTensor] = None,
551
+ bbox_list: Optional[torch.FloatTensor] = None,
552
+ image_grid_thws: Optional[torch.FloatTensor] = None,
553
+ **kwargs,
554
+ ):
555
+ # Wrap parent logic so extra multimodal kwargs are preserved
556
+ model_inputs = super().prepare_inputs_for_generation(
557
+ input_ids,
558
+ past_key_values=past_key_values,
559
+ attention_mask=attention_mask,
560
+ inputs_embeds=inputs_embeds,
561
+ cache_position=cache_position,
562
+ pixel_values=pixel_values,
563
+ pixel_values_videos=pixel_values_videos,
564
+ image_grid_thw=image_grid_thw,
565
+ video_grid_thw=video_grid_thw,
566
+ second_per_grid_ts=second_per_grid_ts,
567
+ images=images,
568
+ images_aux=images_aux,
569
+ bbox_list=bbox_list,
570
+ image_grid_thws=image_grid_thws,
571
+ )
572
+ return model_inputs
573
+
574
+ # Register our config and model with HuggingFace transformers registry
575
+ AutoConfig.register("omchat_qwen2_5_vl", OmChatQwen25VLConfig)
576
+ AutoModelForCausalLM.register(OmChatQwen25VLConfig, OmChatQwen25VLForCausalLM)
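
A minimal usage sketch for the class registered above. The checkpoint path, prompt, image sizes, and grid shapes are placeholders; only the keyword arguments (`images`, `images_aux`, `bbox_list`, `image_grid_thws`) come from the `forward()` / `prepare_inputs_for_generation()` signatures shown in this file.

```python
# Hedged sketch, not the official demo: everything marked "assumed" is a placeholder.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "path/to/VLM-FO1-3B"  # hypothetical local checkpoint exported with this code
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    ckpt, torch_dtype=torch.bfloat16, trust_remote_code=True
).eval()

# In the real pipeline these tensors come from the repo's image processors, and the
# chat template inserts the image/region placeholder tokens into the prompt.
images = [torch.randn(3, 336, 336, dtype=torch.bfloat16)]      # main-tower pixels (shape assumed)
images_aux = [torch.randn(3, 768, 768, dtype=torch.bfloat16)]  # aux DaViT pixels (768 from img_cfg below)
bbox_list = [torch.tensor([[50.0, 40.0, 200.0, 180.0]])]       # one box per image (format assumed)
image_grid_thws = [torch.tensor([[1, 24, 24]])]                # (t, h, w) patch grid (assumed)

prompt = "Describe the object in the marked region."            # placeholder prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

with torch.no_grad():
    output_ids = model.generate(
        input_ids,
        images=images,
        images_aux=images_aux,
        bbox_list=bbox_list,
        image_grid_thws=image_grid_thws,
        max_new_tokens=64,
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```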
vlm_fo1/model/multimodal_encoder/__init__.py ADDED
File without changes
vlm_fo1/model/multimodal_encoder/base_encoder.py ADDED
@@ -0,0 +1,33 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class AbsVisionTower(nn.Module):
6
+ @torch.no_grad()
7
+ def forward(self, images):
8
+ raise NotImplementedError
9
+
10
+ @property
11
+ def dummy_feature(self):
12
+ raise NotImplementedError
13
+
14
+ @property
15
+ def dtype(self):
16
+ raise NotImplementedError
17
+
18
+ @property
19
+ def device(self):
20
+ raise NotImplementedError
21
+
22
+ @property
23
+ def config(self):
24
+ raise NotImplementedError
25
+
26
+
27
+ @property
28
+ def hidden_size(self):
29
+ raise NotImplementedError
30
+
31
+ @property
32
+ def num_patches(self):
33
+ raise NotImplementedError
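
The abstract class above only fixes the method and property names a vision tower must expose; concrete towers fill them in. A minimal, self-contained sketch of a conforming subclass (the toy pooling "encoder" is purely illustrative):

```python
import torch
import torch.nn as nn

from vlm_fo1.model.multimodal_encoder.base_encoder import AbsVisionTower


class DummyVisionTower(AbsVisionTower):
    """Toy tower implementing the AbsVisionTower interface, e.g. for unit tests."""

    def __init__(self, hidden_size=1024, num_patches=256):
        super().__init__()
        self._hidden_size = hidden_size
        self._num_patches = num_patches
        self.proj = nn.Linear(3, hidden_size)  # stand-in for a real backbone

    @torch.no_grad()
    def forward(self, images):
        # images: [B, 3, H, W] -> [B, num_patches, hidden_size] via toy global pooling
        pooled = images.flatten(2).mean(-1)        # [B, 3]
        feats = self.proj(pooled).unsqueeze(1)     # [B, 1, hidden_size]
        return feats.expand(images.shape[0], self._num_patches, self._hidden_size)

    @property
    def dummy_feature(self):
        return torch.zeros(1, self._hidden_size)

    @property
    def dtype(self):
        return self.proj.weight.dtype

    @property
    def device(self):
        return self.proj.weight.device

    @property
    def config(self):
        return {"hidden_size": self._hidden_size}

    @property
    def hidden_size(self):
        return self._hidden_size

    @property
    def num_patches(self):
        return self._num_patches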
vlm_fo1/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,38 @@
1
+ # Builders for different vision tower backbones (MM encoder visual modules)
2
+ from .qwen2_5_vl_encoder import Qwen2_5_VlVisionTower # Main Qwen2.5 vision tower
3
+ from .davit_aux_encoder import DavitVisionTower as DavitVisionTowerAux # Auxiliary DaViT vision tower
4
+
5
+ def build_vision_tower(vision_tower_cfg, **kwargs):
6
+ """
7
+ Use model config to construct the main vision tower.
8
+
9
+ vision_tower_cfg: should have attribute mm_vision_tower
10
+ Returns: instance of configured vision backbone
11
+ """
12
+ vision_tower_name = getattr(vision_tower_cfg, 'mm_vision_tower', None)
13
+ print(vision_tower_cfg) # Debug print of the config being used
14
+
15
+ # Check for the Qwen2.5-VL vision model in tower name
16
+ if "qwen2.5-vl" in vision_tower_name.lower():
17
+ return Qwen2_5_VlVisionTower(vision_tower_name, args=vision_tower_cfg, **kwargs)
18
+
19
+ # Raise a clear error for unknown towers
20
+ raise ValueError(f'Unknown vision tower: {vision_tower_name}')
21
+
22
+ def build_vision_tower_aux(vision_tower_cfg, **kwargs):
23
+ """
24
+ Use model config to construct the auxiliary (helper) vision tower.
25
+
26
+ vision_tower_cfg: should have attribute mm_vision_tower_aux
27
+ Returns: instance of configured auxiliary vision backbone
28
+ """
29
+ vision_tower_aux = getattr(vision_tower_cfg, 'mm_vision_tower_aux', None)
30
+ # Optionally print config for debugging
31
+ # print(vision_tower_cfg)
32
+
33
+ # Check for the DaViT auxiliary vision model in tower name
34
+ if 'davit' in vision_tower_aux.lower():
35
+ return DavitVisionTowerAux(vision_tower_aux, args=vision_tower_cfg, **kwargs)
36
+
37
+ # Raise a clear error if tower type is unknown
38
+ raise ValueError(f'Unknown aux vision tower: {vision_tower_aux}')
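
A small sketch of how these builders are typically driven from a model config. The attribute values are placeholders; actually calling the builders constructs and loads the corresponding backbones, so only the dispatch condition is exercised here.

```python
from types import SimpleNamespace

# Hypothetical config values; real ones come from the checkpoint's config.json.
cfg = SimpleNamespace(
    mm_vision_tower="Qwen/Qwen2.5-VL-3B-Instruct",  # matched by the "qwen2.5-vl" check
    mm_vision_tower_aux="davit-large",               # matched by the "davit" check
)

# build_vision_tower(cfg) would return a Qwen2_5_VlVisionTower and
# build_vision_tower_aux(cfg) a DavitVisionTowerAux; both load weights.
print("main tower is qwen2.5-vl:", "qwen2.5-vl" in cfg.mm_vision_tower.lower())
print("aux tower is davit:", "davit" in cfg.mm_vision_tower_aux.lower())
```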
vlm_fo1/model/multimodal_encoder/davit/configs.py ADDED
@@ -0,0 +1,152 @@
1
+
2
+ model_configs = {
3
+ "davit-base": {
4
+ "depths": [
5
+ 1,
6
+ 1,
7
+ 9,
8
+ 1
9
+ ],
10
+ "dim_embed": [
11
+ 128,
12
+ 256,
13
+ 512,
14
+ 1024
15
+ ],
16
+ "drop_path_rate": 0.1,
17
+ "enable_checkpoint": True,
18
+ "image_feature_source": [
19
+ "spatial_avg_pool",
20
+ "temporal_avg_pool"
21
+ ],
22
+ "image_pos_embed": {
23
+ "max_pos_embeddings": 50,
24
+ "type": "learned_abs_2d"
25
+ },
26
+ "num_groups": [
27
+ 4,
28
+ 8,
29
+ 16,
30
+ 32
31
+ ],
32
+ "num_heads": [
33
+ 4,
34
+ 8,
35
+ 16,
36
+ 32
37
+ ],
38
+ "patch_padding": [
39
+ 3,
40
+ 1,
41
+ 1,
42
+ 1
43
+ ],
44
+ "patch_prenorm": [
45
+ False,
46
+ True,
47
+ True,
48
+ True
49
+ ],
50
+ "patch_size": [
51
+ 7,
52
+ 3,
53
+ 3,
54
+ 3
55
+ ],
56
+ "patch_stride": [
57
+ 4,
58
+ 2,
59
+ 2,
60
+ 2
61
+ ],
62
+ "projection_dim": 768,
63
+ "transformers_version": "4.41.2",
64
+ "visual_temporal_embedding": {
65
+ "max_temporal_embeddings": 100,
66
+ "type": "COSINE"
67
+ },
68
+ "window_size": 12
69
+ },
70
+ "davit-large": {
71
+ "depths": [
72
+ 1,
73
+ 1,
74
+ 9,
75
+ 1
76
+ ],
77
+ "dim_embed": [
78
+ 256,
79
+ 512,
80
+ 1024,
81
+ 2048
82
+ ],
83
+ "drop_path_rate": 0.1,
84
+ "enable_checkpoint": True,
85
+ "image_feature_source": [
86
+ "spatial_avg_pool",
87
+ "temporal_avg_pool"
88
+ ],
89
+ "image_pos_embed": {
90
+ "max_pos_embeddings": 50,
91
+ "type": "learned_abs_2d"
92
+ },
93
+ "num_groups": [
94
+ 8,
95
+ 16,
96
+ 32,
97
+ 64
98
+ ],
99
+ "num_heads": [
100
+ 8,
101
+ 16,
102
+ 32,
103
+ 64
104
+ ],
105
+ "patch_padding": [
106
+ 3,
107
+ 1,
108
+ 1,
109
+ 1
110
+ ],
111
+ "patch_prenorm": [
112
+ False,
113
+ True,
114
+ True,
115
+ True
116
+ ],
117
+ "patch_size": [
118
+ 7,
119
+ 3,
120
+ 3,
121
+ 3
122
+ ],
123
+ "patch_stride": [
124
+ 4,
125
+ 2,
126
+ 2,
127
+ 2
128
+ ],
129
+ "projection_dim": 1024,
130
+ "transformers_version": "4.41.2",
131
+ "visual_temporal_embedding": {
132
+ "max_temporal_embeddings": 100,
133
+ "type": "COSINE"
134
+ },
135
+ "window_size": 12
136
+ }
137
+ }
138
+
139
+ img_cfg = {
140
+ "do_resize": True,
141
+ "size": {
142
+ "height": 768,
143
+ "width":768
144
+ },
145
+ "resample": 3,
146
+ "do_center_crop": False,
147
+ "do_rescale": True,
148
+ "do_normalize": True,
149
+ "image_mean": [0.485, 0.456, 0.406],
150
+ "image_std": [0.229, 0.224, 0.225],
151
+ "do_convert_rgb": True
152
+ }
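
The `img_cfg` dictionary above follows the Hugging Face image-processor convention: resize to 768x768 with bicubic resampling (`resample=3` is PIL's BICUBIC), rescale to [0, 1], then normalize with ImageNet statistics. A torchvision sketch that mirrors those settings; the exact processor class the repo uses is not shown here, so this is only an approximation, and the input filename is a placeholder.

```python
from PIL import Image
import torchvision.transforms as T

# Mirrors img_cfg: 768x768 bicubic resize, rescale, ImageNet normalization.
davit_preprocess = T.Compose([
    T.Resize((768, 768), interpolation=T.InterpolationMode.BICUBIC),
    T.ToTensor(),  # rescales pixel values to [0, 1]
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
])

img = Image.open("example.jpg").convert("RGB")  # do_convert_rgb
pixels = davit_preprocess(img)                  # float tensor of shape [3, 768, 768]
```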
vlm_fo1/model/multimodal_encoder/davit/configuration_davit.py ADDED
@@ -0,0 +1,119 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional
16
+
17
+ from transformers import AutoConfig
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+ class DavitConfig(PretrainedConfig):
24
+ r"""
25
+ This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel
26
+ according to the specified arguments, defining the model architecture. Instantiating a configuration with the
27
+ defaults will yield a similar configuration to that of the Florence2VisionModel architecture.
28
+
29
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
30
+ documentation from [`PretrainedConfig`] for more information.
31
+
32
+ Args:
33
+ drop_path_rate (`float`, *optional*, defaults to 0.1):
34
+ The dropout rate of the drop path layer.
35
+ patch_size (`List[int]`, *optional*, defaults to [7, 3, 3, 3]):
36
+ The patch size of the image.
37
+ patch_stride (`List[int]`, *optional*, defaults to [4, 2, 2, 2]):
38
+ The patch stride of the image.
39
+ patch_padding (`List[int]`, *optional*, defaults to [3, 1, 1, 1]):
40
+ The patch padding of the image.
41
+ patch_prenorm (`List[bool]`, *optional*, defaults to [false, true, true, true]):
42
+ Whether to apply layer normalization before the patch embedding layer.
43
+ enable_checkpoint (`bool`, *optional*, defaults to False):
44
+ Whether to enable checkpointing.
45
+ dim_embed (`List[int]`, *optional*, defaults to [256, 512, 1024, 2048]):
46
+ The dimension of the embedding layer.
47
+ num_heads (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
48
+ The number of attention heads.
49
+ num_groups (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
50
+ The number of groups.
51
+ depths (`List[int]`, *optional*, defaults to [1, 1, 9, 1]):
52
+ The depth of the model.
53
+ window_size (`int`, *optional*, defaults to 12):
54
+ The window size of the model.
55
+ projection_dim (`int`, *optional*, defaults to 1024):
56
+ The dimension of the projection layer.
57
+ visual_temporal_embedding (`dict`, *optional*):
58
+ The configuration of the visual temporal embedding.
59
+ image_pos_embed (`dict`, *optional*):
60
+ The configuration of the image position embedding.
61
+ image_feature_source (`List[str]`, *optional*, defaults to ["spatial_avg_pool", "temporal_avg_pool"]):
62
+ The source of the image feature.
63
+ Example:
64
+
65
+ ```python
66
+ >>> from transformers import Florence2VisionConfig, Florence2VisionModel
67
+
68
+ >>> # Initializing a Florence2 Vision style configuration
69
+ >>> configuration = Florence2VisionConfig()
70
+
71
+ >>> # Initializing a model (with random weights)
72
+ >>> model = Florence2VisionModel(configuration)
73
+
74
+ >>> # Accessing the model configuration
75
+ >>> configuration = model.config
76
+ ```"""
77
+
78
+ model_type = "florence2_vision"
79
+ keys_to_ignore_at_inference = ["past_key_values"]
80
+
81
+ def __init__(
82
+ self,
83
+ drop_path_rate=0.1,
84
+ patch_size=[7, 3, 3, 3],
85
+ patch_stride=[4, 2, 2, 2],
86
+ patch_padding=[3, 1, 1, 1],
87
+ patch_prenorm=[False, True, True, True],
88
+ enable_checkpoint=False,
89
+ dim_embed=[256, 512, 1024, 2048],
90
+ num_heads=[8, 16, 32, 64],
91
+ num_groups=[8, 16, 32, 64],
92
+ depths=[1, 1, 9, 1],
93
+ window_size=12,
94
+ projection_dim=1024,
95
+ visual_temporal_embedding=None,
96
+ image_pos_embed=None,
97
+ image_feature_source=["spatial_avg_pool", "temporal_avg_pool"],
98
+ **kwargs,
99
+ ):
100
+ self.drop_path_rate = drop_path_rate
101
+ self.patch_size = patch_size
102
+ self.patch_stride = patch_stride
103
+ self.patch_padding = patch_padding
104
+ self.patch_prenorm = patch_prenorm
105
+ self.enable_checkpoint = enable_checkpoint
106
+ self.dim_embed = dim_embed
107
+ self.num_heads = num_heads
108
+ self.num_groups = num_groups
109
+ self.depths = depths
110
+ self.window_size = window_size
111
+ self.projection_dim = projection_dim
112
+ self.visual_temporal_embedding = visual_temporal_embedding
113
+ self.image_pos_embed = image_pos_embed
114
+ self.image_feature_source = image_feature_source
115
+
116
+ super().__init__(**kwargs)
117
+
118
+
119
+
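
Putting the two DaViT files together: the dictionaries in `configs.py` can be unpacked directly into `DavitConfig`, since every key is either an explicit `__init__` argument or passes through `**kwargs` to `PretrainedConfig`. A hedged sketch; the import paths assume the package layout shown in this commit.

```python
from vlm_fo1.model.multimodal_encoder.davit.configs import model_configs
from vlm_fo1.model.multimodal_encoder.davit.configuration_davit import DavitConfig

cfg = DavitConfig(**model_configs["davit-large"])
print(cfg.dim_embed)       # [256, 512, 1024, 2048]
print(cfg.projection_dim)  # 1024
print(cfg.window_size)     # 12
```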