Spaces: Running on Zero

PengLiu committed · 6302644 · Parent(s): c4a1381

update

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- app.py +9 -0
- demo/gradio_demo2.py +369 -0
- demo/sam3_examples/init.py +0 -0
- detect_tools/upn/__init__.py +45 -0
- detect_tools/upn/builder.py +39 -0
- detect_tools/upn/configs/upn_large.py +73 -0
- detect_tools/upn/inference_wrapper.py +237 -0
- detect_tools/upn/models/architecture/__init__.py +4 -0
- detect_tools/upn/models/architecture/deformable_transformer.py +336 -0
- detect_tools/upn/models/architecture/upn_model.py +343 -0
- detect_tools/upn/models/backbone/__init__.py +4 -0
- detect_tools/upn/models/backbone/swin.py +814 -0
- detect_tools/upn/models/backbone/wrapper.py +297 -0
- detect_tools/upn/models/decoder/__init__.py +3 -0
- detect_tools/upn/models/decoder/upn_decoder.py +378 -0
- detect_tools/upn/models/encoder/__init__.py +3 -0
- detect_tools/upn/models/encoder/upn_encoder.py +288 -0
- detect_tools/upn/models/module/__init__.py +5 -0
- detect_tools/upn/models/module/contrastive.py +29 -0
- detect_tools/upn/models/module/mlp.py +18 -0
- detect_tools/upn/models/module/nested_tensor.py +199 -0
- detect_tools/upn/models/utils/__init__.py +23 -0
- detect_tools/upn/models/utils/detr_utils.py +415 -0
- detect_tools/upn/ops/functions/__init__.py +10 -0
- detect_tools/upn/ops/functions/ms_deform_attn_func.py +61 -0
- detect_tools/upn/ops/modules/__init__.py +9 -0
- detect_tools/upn/ops/modules/ms_deform_attn.py +204 -0
- detect_tools/upn/ops/modules/ms_deform_attn_key_aware.py +130 -0
- detect_tools/upn/ops/setup.py +73 -0
- detect_tools/upn/ops/src/cpu/ms_deform_attn_cpu.cpp +41 -0
- detect_tools/upn/ops/src/cpu/ms_deform_attn_cpu.h +33 -0
- detect_tools/upn/ops/src/cuda/ms_deform_attn_cuda.cu +153 -0
- detect_tools/upn/ops/src/cuda/ms_deform_attn_cuda.h +30 -0
- detect_tools/upn/ops/src/cuda/ms_deform_im2col_cuda.cuh +1327 -0
- detect_tools/upn/ops/src/ms_deform_attn.h +62 -0
- detect_tools/upn/ops/src/vision.cpp +16 -0
- detect_tools/upn/ops/test.py +89 -0
- detect_tools/upn/requirments.txt +1 -0
- detect_tools/upn/transforms/transform.py +142 -0
- requirements.txt +19 -0
- resources/__init__.py +0 -0
- vlm_fo1/__init__.py +1 -0
- vlm_fo1/constants.py +29 -0
- vlm_fo1/mm_utils.py +660 -0
- vlm_fo1/model/__init__.py +1 -0
- vlm_fo1/model/builder.py +89 -0
- vlm_fo1/model/language_model/omchat_qwen2_5_vl.py +576 -0
- vlm_fo1/model/multimodal_encoder/__init__.py +0 -0
- vlm_fo1/model/multimodal_encoder/base_encoder.py +33 -0
- vlm_fo1/model/multimodal_encoder/builder.py +38 -0
app.py
ADDED
@@ -0,0 +1,9 @@
+import gradio as gr
+import spaces
+
+@spaces.GPU
+def greet(name):
+    return "Hello " + name + "!!"
+
+demo = gr.Interface(fn=greet, inputs="text", outputs="text")
+demo.launch()
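This `app.py` is the ZeroGPU placeholder pattern: the `@spaces.GPU` decorator asks Hugging Face Spaces to attach a GPU only while the decorated function runs. A minimal sketch of the same pattern with an explicit duration hint; the `duration` keyword is an assumption based on current ZeroGPU releases, not something this commit uses:

import spaces
import torch

@spaces.GPU(duration=120)  # assumed keyword: request the GPU for up to ~120 s per call
def heavy_inference(prompt: str) -> str:
    # on ZeroGPU, CUDA is only guaranteed inside the decorated function
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return f"ran on {device}: {prompt}"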
demo/gradio_demo2.py
ADDED
@@ -0,0 +1,369 @@
+import gradio as gr
+import spaces
+from PIL import Image, ImageDraw, ImageFont
+import re
+import random
+import numpy as np
+from skimage.measure import label, regionprops
+from skimage.morphology import binary_dilation, disk
+from sam3.model_builder import build_sam3_image_model
+from sam3.model.sam3_image_processor import Sam3Processor
+from sam3.visualization_utils import plot_bbox, plot_mask, COLORS
+import matplotlib.pyplot as plt
+
+from detect_tools.upn import UPNWrapper
+from vlm_fo1.model.builder import load_pretrained_model
+from vlm_fo1.mm_utils import (
+    prepare_inputs,
+    extract_predictions_to_indexes,
+)
+from vlm_fo1.task_templates import *
+import torch
+import os
+from copy import deepcopy
+
+
+EXAMPLES = [
+    ["demo/sam3_examples/00000-72.jpg", "airplane with letter AE on its body"],
+    ["demo/sam3_examples/00000-32.jpg", "the lying cat which is not black"],
+    ["demo/sam3_examples/00000-22.jpg", "person wearing a black top"],
+    ["demo/sam3_examples/000000378453.jpg", "zebra inside the mud puddle"],
+]
+
+
+def get_valid_examples():
+    valid_examples = []
+    demo_dir = os.path.dirname(os.path.abspath(__file__))
+    for example in EXAMPLES:
+        img_path = example[0]
+        full_path = os.path.join(demo_dir, img_path)
+        if os.path.exists(full_path):
+            valid_examples.append([
+                full_path,
+                example[1],
+            ])
+        elif os.path.exists(img_path):
+            valid_examples.append([
+                img_path,
+                example[1],
+            ])
+    return valid_examples
+
+
+def detect_model_upn(image, threshold=0.3):
+    proposals = upn_model.inference(image)
+    filtered_proposals = upn_model.filter(proposals, min_score=threshold)
+    picked_proposals = filtered_proposals['original_xyxy_boxes'][0][:100]
+    return picked_proposals
+
+
+def detect_model_sam3(image, text, threshold=0.3):
+    inference_state = sam3_processor.set_image(image)
+    output = sam3_processor.set_text_prompt(state=inference_state, prompt=text)
+    boxes, scores, masks = output["boxes"], output["scores"], output["masks"]
+    sorted_indices = torch.argsort(scores, descending=True)
+    boxes = boxes[sorted_indices][:100, :]
+    scores = scores[sorted_indices][:100]
+    masks = masks[sorted_indices][:100]
+
+    output = {
+        "boxes": boxes,
+        "scores": scores,
+        "masks": masks,
+    }
+    return boxes.tolist(), scores.tolist(), masks.tolist(), output
+
+
+def multimodal_model(image, bboxes, text, scores=None):
+    if len(bboxes) == 0:
+        return None, {}, []
+
+    if '<image>' in text:
+        print(text)
+        parts = [part.replace('\\n', '\n') for part in re.split(rf'(<image>)', text) if part.strip()]
+        print(parts)
+        content = []
+        for part in parts:
+            if part == '<image>':
+                content.append({"type": "image_url", "image_url": {"url": image}})
+            else:
+                content.append({"type": "text", "text": part})
+    else:
+        content = [{
+            "type": "image_url",
+            "image_url": {
+                "url": image
+            }
+        }, {
+            "type": "text",
+            "text": text
+        }]
+
+    messages = [
+        {
+            "role": "user",
+            "content": content,
+            "bbox_list": bboxes
+        }
+    ]
+    generation_kwargs = prepare_inputs(model_path, model, image_processors, tokenizer, messages,
+                                       max_tokens=4096, top_p=0.05, temperature=0.0, do_sample=False, image_size=1024)
+    with torch.inference_mode():
+        output_ids = model.generate(**generation_kwargs)
+    outputs = tokenizer.decode(output_ids[0, generation_kwargs['inputs'].shape[1]:]).strip()
+    print("========output========\n", outputs)
+
+    if '<ground>' in outputs:
+        prediction_dict = extract_predictions_to_indexes(outputs)
+    else:
+        match_pattern = r"<region(\d+)>"
+        matches = re.findall(match_pattern, outputs)
+        prediction_dict = {f"<region{m}>": {int(m)} for m in matches}
+
+    ans_bbox_json = []
+    ans_bbox_list = []
+    for k, v in prediction_dict.items():
+        for box_index in v:
+            box_index = int(box_index)
+            if box_index < len(bboxes):
+                current_bbox = bboxes[box_index]
+                item = {
+                    "region_index": f"<region{box_index}>",
+                    "xmin": current_bbox[0],
+                    "ymin": current_bbox[1],
+                    "xmax": current_bbox[2],
+                    "ymax": current_bbox[3],
+                    "label": k,
+                }
+                if scores is not None and box_index < len(scores):
+                    item["score"] = scores[box_index]
+
+                ans_bbox_json.append(item)
+                ans_bbox_list.append(current_bbox)
+
+    return outputs, ans_bbox_json, ans_bbox_list
+
+
+def draw_sam3_results(img, results):
+    fig, ax = plt.subplots(figsize=(12, 8))
+    # fig.subplots_adjust(0, 0, 1, 1)
+    ax.imshow(img)
+    nb_objects = len(results["scores"])
+    print(f"found {nb_objects} object(s)")
+    for i in range(nb_objects):
+        color = COLORS[i % len(COLORS)]
+        plot_mask(results["masks"][i].squeeze(0).cpu(), color=color)
+        w, h = img.size
+        prob = results["scores"][i].item()
+        plot_bbox(
+            h,
+            w,
+            results["boxes"][i].cpu(),
+            text=f"(id={i}, {prob=:.2f})",
+            box_format="XYXY",
+            color=color,
+            relative_coords=False,
+        )
+    ax.axis("off")
+    fig.tight_layout(pad=0)
+
+    # Convert matplotlib figure to PIL Image
+    fig.canvas.draw()
+    buf = fig.canvas.buffer_rgba()
+    pil_img = Image.frombytes('RGBA', fig.canvas.get_width_height(), buf)
+    plt.close(fig)
+
+    return pil_img
+
+
+def draw_bboxes_simple(image, bboxes, labels=None):
+    image = image.copy()
+    draw = ImageDraw.Draw(image)
+
+    for bbox in bboxes:
+        draw.rectangle(bbox, outline="red", width=3)
+    return image
+
+
+@spaces.GPU
+def process(image, prompt, threshold=0.3):
+    if image is None:
+        error_msg = "Error: Please upload an image or select a valid example."
+        print(f"Error: image is None, original input type: {type(image)}")
+        return None, None, None, None, []  # one value per Gradio output (5)
+
+    try:
+        image = image.convert('RGB')
+    except Exception as e:
+        error_msg = f"Error: Cannot process image - {str(e)}"
+        return None, None, None, None, []  # one value per Gradio output (5)
+
+    # --- SAM3 Pipeline ---
+    print("Running SAM3 Pipeline...")
+    sam3_bboxes, sam3_scores, masks, sam3_output = detect_model_sam3(image, prompt, threshold)
+
+    # Generate SAM3 outputs (directly from SAM3, no VLM-FO1)
+    sam3_detection_image = draw_sam3_results(image, sam3_output)
+
+    sam3_annotated_bboxes = []
+    sam3_ans_bbox_json = []
+
+    img_width, img_height = image.size
+    for i, bbox in enumerate(sam3_bboxes):
+        xmin = max(0, min(img_width, int(bbox[0])))
+        ymin = max(0, min(img_height, int(bbox[1])))
+        xmax = max(0, min(img_width, int(bbox[2])))
+        ymax = max(0, min(img_height, int(bbox[3])))
+        score = sam3_scores[i]
+
+        # Format label with score
+        label_text = f"{prompt} {score:.2f}"
+
+        sam3_annotated_bboxes.append(
+            ((xmin, ymin, xmax, ymax), label_text)
+        )
+
+        sam3_ans_bbox_json.append({
+            "region_index": i,
+            "xmin": bbox[0],
+            "ymin": bbox[1],
+            "xmax": bbox[2],
+            "ymax": bbox[3],
+            "label": prompt,
+            "score": score
+        })
+
+    sam3_annotated_image = (image, sam3_annotated_bboxes)
+
+    # --- UPN Pipeline ---
+    print("Running UPN Pipeline...")
+    upn_bboxes = detect_model_upn(image, threshold=0.3)  # Use default threshold for UPN
+
+    fo1_prompt_upn = OD_template.format(prompt)
+    upn_bboxes = upn_bboxes[::-1]
+    upn_ans, upn_ans_bbox_json, upn_ans_bbox_list = multimodal_model(image, upn_bboxes, fo1_prompt_upn)
+
+    upn_detection_image = draw_bboxes_simple(image, upn_bboxes)
+
+    upn_annotated_bboxes = []
+    if len(upn_ans_bbox_json) > 0:
+        img_width, img_height = image.size
+        for item in upn_ans_bbox_json:
+            xmin = max(0, min(img_width, int(item['xmin'])))
+            ymin = max(0, min(img_height, int(item['ymin'])))
+            xmax = max(0, min(img_width, int(item['xmax'])))
+            ymax = max(0, min(img_height, int(item['ymax'])))
+            upn_annotated_bboxes.append(
+                ((xmin, ymin, xmax, ymax), item['label'])
+            )
+    upn_annotated_image = (image, upn_annotated_bboxes)
+
+    return sam3_annotated_image, sam3_detection_image, \
+        upn_annotated_image, upn_detection_image, upn_ans_bbox_json
+
+
+def update_btn(is_processing):
+    if is_processing:
+        return gr.update(value="Processing...", interactive=False)
+    else:
+        return gr.update(value="Submit", interactive=True)
+
+
+def launch_demo():
+    with gr.Blocks() as demo:
+        gr.Markdown("# 🚀 VLM-FO1 vs SAM3 Demo")
+        gr.Markdown("""
+        ### 📋 Instructions
+        Compare the detection performance of **SAM3** vs **VLM-FO1**.
+
+        **How it works**
+        1. Upload or pick an example image.
+        2. Describe the target object in natural language.
+        3. Hit **Submit** to run both pipelines.
+        """)
+
+        with gr.Row():
+            with gr.Column():
+                img_input_draw = gr.Image(
+                    label="Image Input",
+                    type="pil",
+                    sources=['upload'],
+                )
+
+                gr.Markdown("### Prompt")
+
+                prompt_input = gr.Textbox(
+                    label="Label Prompt",
+                    lines=2,
+                )
+
+                submit_btn = gr.Button("Submit", variant="primary")
+
+                examples = gr.Examples(
+                    examples=EXAMPLES,
+                    inputs=[img_input_draw, prompt_input],
+                    label="Click to load example",
+                    examples_per_page=5
+                )
+
+            with gr.Column():
+                gr.Markdown("### SAM3 Result")
+                with gr.Accordion("SAM3 Masks & Boxes", open=False):
+                    sam3_detection_output = gr.Image(label="SAM3 Visualization", height=300)
+
+                sam3_final_output = gr.AnnotatedImage(label="SAM3 Detections", height=400)
+                # sam3_json_output = gr.JSON(label="SAM3 Output Data")
+
+            with gr.Column():
+                gr.Markdown("### VLM-FO1 Result")
+                with gr.Accordion("Bboxes Proposals", open=False):
+                    upn_detection_output = gr.Image(label="Bboxes", height=300)
+
+                upn_final_output = gr.AnnotatedImage(label="VLM-FO1 Final", height=400)
+                upn_json_output = gr.JSON(label="VLM-FO1 Details")
+
+        submit_btn.click(
+            update_btn,
+            inputs=[gr.State(True)],
+            outputs=[submit_btn],
+            queue=False
+        ).then(
+            process,
+            inputs=[img_input_draw, prompt_input],
+            outputs=[
+                sam3_final_output, sam3_detection_output,
+                upn_final_output, upn_detection_output, upn_json_output
+            ],
+            queue=True
+        ).then(
+            update_btn,
+            inputs=[gr.State(False)],
+            outputs=[submit_btn],
+            queue=False
+        )
+
+    return demo
+
+if __name__ == "__main__":
+    import os
+    exit_code = os.system(f"wget -c https://airesources.oss-cn-hangzhou.aliyuncs.com/lp/wheel/sam3.pt")
+
+    model_path = 'omlab/VLM-FO1_Qwen2.5-VL-3B-v01'
+    # sam3_model_path = './resources/sam3/sam3.pt'
+    upn_ckpt_path = "./resources/upn_large.pth"
+
+    # Load FO1
+    tokenizer, model, image_processors = load_pretrained_model(
+        model_path=model_path,
+        device="cuda:0",
+    )
+
+    # Load SAM3
+    sam3_model = build_sam3_image_model(checkpoint_path='./sam3.pt', device="cuda", bpe_path='/home/user/app/resources/bpe_simple_vocab_16e6.txt.gz')
+    sam3_processor = Sam3Processor(sam3_model, confidence_threshold=0.0, device="cuda")
+
+    # Load UPN
+    upn_model = UPNWrapper(upn_ckpt_path)
+
+    demo = launch_demo()
+    demo.launch()
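For quick debugging it can be useful to exercise `process` without launching the UI. A minimal headless sketch, assuming the model-loading section of the `__main__` block above has already run (so the module-level globals `model`, `tokenizer`, `image_processors`, `sam3_processor`, and `upn_model` exist) and a CUDA device is available:

from PIL import Image

# run one example through both pipelines, bypassing Gradio
img = Image.open("demo/sam3_examples/00000-22.jpg")
(sam3_annotated, sam3_vis,
 upn_annotated, upn_vis, upn_json) = process(img, "person wearing a black top")

print(upn_json)                # VLM-FO1 detections as a list of bbox dicts
sam3_vis.save("sam3_vis.png")  # matplotlib rendering of SAM3 masks and boxes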
demo/sam3_examples/init.py
ADDED
File without changes
detect_tools/upn/__init__.py
ADDED
@@ -0,0 +1,45 @@
+from . import models
+from .builder import (
+    ARCHITECTURES,
+    BACKBONES,
+    DECODERS,
+    ENCODERS,
+    POS_EMBEDDINGS,
+    build_architecture,
+    build_backbone,
+    build_decoder,
+    build_encoder,
+    build_position_embedding,
+)
+from .inference_wrapper import UPNWrapper
+from .models.architecture import *
+from .models.backbone import *
+from .models.decoder import *
+from .models.encoder import *
+from .models.module import *
+from .models.utils import *
+
+__all__ = [
+    "BACKBONES",
+    "POS_EMBEDDINGS",
+    "ENCODERS",
+    "DECODERS",
+    "ARCHITECTURES",
+    "build_backbone",
+    "build_position_embedding",
+    "build_encoder",
+    "build_decoder",
+    "build_architecture",
+    "UPNWrapper",
+]
+
+__all__ += (
+    models.module.__all__
+    + models.utils.__all__
+    + models.architecture.__all__
+    + models.backbone.__all__
+    + models.encoder.__all__
+    + models.decoder.__all__
+    + models.module.__all__
+    + models.utils.__all__
+)
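The effect of these re-exports is that downstream code can import everything from the package root, which is exactly how the rest of this commit uses it (`inference_wrapper.py` imports `build_architecture` from `detect_tools.upn`). A small illustration:

# registries and the wrapper are all reachable from the package root
from detect_tools.upn import UPNWrapper, ARCHITECTURES

print(ARCHITECTURES.get("UPN"))  # the UPN class, registered on import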
detect_tools/upn/builder.py
ADDED
@@ -0,0 +1,39 @@
+from mmengine import Registry, build_from_cfg
+
+BACKBONES = Registry("backbone")
+POS_EMBEDDINGS = Registry("position_embedding")
+FUSERS = Registry("fuser")
+ENCODERS = Registry("encoder")
+DECODERS = Registry("decoder")
+ARCHITECTURES = Registry("architecture")
+
+
+def build_backbone(cfg):
+    """Build backbone."""
+    return build_from_cfg(cfg, BACKBONES)
+
+
+def build_position_embedding(cfg):
+    """Build position embedding."""
+    return build_from_cfg(cfg, POS_EMBEDDINGS)
+
+
+def build_fuser(cfg):
+    """Build fuser."""
+    return build_from_cfg(cfg, FUSERS)
+
+
+def build_encoder(cfg):
+    """Build encoder."""
+    return build_from_cfg(cfg, ENCODERS)
+
+
+def build_decoder(cfg):
+    """Build decoder."""
+    return build_from_cfg(cfg, DECODERS)
+
+
+def build_architecture(cfg):
+    """Build architecture."""
+
+    return build_from_cfg(cfg, ARCHITECTURES)
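These registries decouple config files from classes: a class registers itself by name, and `build_from_cfg` instantiates it from a dict whose "type" key names the class and whose remaining keys become constructor kwargs. A sketch of the pattern with a hypothetical `ToyEncoder` (not part of this commit):

import torch.nn as nn
from detect_tools.upn.builder import ENCODERS, build_encoder

@ENCODERS.register_module()      # registered under its class name
class ToyEncoder(nn.Module):     # hypothetical, for illustration only
    def __init__(self, d_model: int = 256):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model)

# "type" selects the registered class; the other keys become __init__ kwargs
enc = build_encoder(dict(type="ToyEncoder", d_model=256))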
detect_tools/upn/configs/upn_large.py
ADDED
@@ -0,0 +1,73 @@
+transformer_cfg = dict(
+    type="DeformableTransformer",
+    num_queries=900,
+    encoder_cfg=dict(
+        type="UPNEncoder",
+        encoder_layer_cfg=dict(
+            type="DeformableTransformerEncoderLayer",
+            activation="relu",
+            d_model=256,
+            dropout=0.0,
+            d_ffn=2048,
+            n_heads=8,
+            n_levels=5,
+        ),
+        d_model=256,
+        num_layers=6,
+        use_checkpoint=False,
+        use_transformer_ckpt=False,
+    ),
+    decoder_cfg=dict(
+        type="UPNDecoder",
+        decoder_layer_cfg=dict(
+            type="DeformableTransformerDecoderLayer",
+            activation="relu",
+            d_model=256,
+            n_heads=8,
+            dropout=0.0,
+            d_ffn=2048,
+            n_levels=5,
+        ),
+        d_model=256,
+        return_intermediate=True,
+        num_layers=6,
+        rm_dec_query_scale=True,
+        use_detached_boxes_dec_out=False,
+    ),
+    learnable_tgt_init=True,
+    random_refpoints_xy=False,
+    num_feature_levels=5,
+    two_stage_bbox_embed_share=False,
+    two_stage_class_embed_share=False,
+    two_stage_keep_all_tokens=False,
+    two_stage_learn_wh=False,
+    two_stage_type="standard",
+    binary_query_selection=False,
+)
+
+vision_backbone = dict(
+    type="SwinWrapper",
+    backbone_cfg="swin_L_384_22k",
+    lr_backbone=1e-05,
+    dilation=False,
+    return_interm_indices=[0, 1, 2, 3],
+    backbone_freeze_keywords=None,
+    backbone_ckpt_path=None,
+    use_checkpoint=False,
+    position_embedding_cfg=dict(
+        type="PositionEmbeddingSineHW",
+        normalize=True,
+        num_pos_feats=128,
+        temperatureH=20,
+        temperatureW=20,
+    ),
+)
+
+model = dict(
+    type="UPN",
+    vision_backbone_cfg=vision_backbone,
+    transformer_cfg=transformer_cfg,
+    num_queries=900,
+    dec_pred_bbox_embed_share=True,
+    dec_pred_class_embed_share=True,
+)
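This config is a plain mmengine-style Python file; `build_model` in `inference_wrapper.py` loads it with `Config.fromfile` and hands the `model` dict to `build_architecture`. A minimal sketch of that loading step, mirroring the wrapper code:

from mmengine import Config
from detect_tools.upn import build_architecture

cfg = Config.fromfile("detect_tools/upn/configs/upn_large.py")
print(cfg.model["type"])               # "UPN"
model = build_architecture(cfg.model)  # instantiate UPN from the dict above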
detect_tools/upn/inference_wrapper.py
ADDED
@@ -0,0 +1,237 @@
+import copy
+import os
+from typing import Dict, List, Union
+
+import numpy as np
+import torch
+from mmengine import Config
+from PIL import Image
+from torchvision.ops import nms
+
+import detect_tools.upn.transforms.transform as T
+from detect_tools.upn import build_architecture
+from detect_tools.upn.models.module import nested_tensor_from_tensor_list
+
+
+def build_model(
+    ckpt_path: str,
+):
+    current_path = os.path.dirname(os.path.abspath(__file__))
+    config_file = f"configs/upn_large.py"
+    config_path = os.path.join(current_path, config_file)
+    model_cfg = Config.fromfile(config_path).model
+    model = build_architecture(model_cfg)
+    checkpoint = torch.load(ckpt_path, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    return model
+
+
+class UPNWrapper:
+    """A wrapper class for the UPN model.
+
+    Args:
+        ckpt_path (str): The path to the model checkpoint.
+    """
+
+    def __init__(self, ckpt_path: str):
+
+        self.model = build_model(ckpt_path)
+        self.model.eval()
+        self.model.to("cuda")
+
+    def inference(
+        self,
+        image: List[Union[str, Image.Image]],
+        prompt_type: str = 'fine_grained_prompt',
+    ):
+        """Single image prediction.
+
+        Args:
+            image (List[Union[str, Image.Image]]): A list of image paths or
+                PIL.Image.Image objects.
+            prompt_type (str): The type of prompt to use for the prediction. One of
+                ['fine_grained_prompt', 'coarse_grained_prompt'].
+
+        Returns:
+            Dict: Return dict in format:
+                {
+                    "original_xyxy_boxes": (np.ndarray): Original prediction boxes in shape (batch_size, 900, 4),
+                    "scores": (np.ndarray): Scores in shape (batch_size, N)
+                }
+        """
+        if not isinstance(image, list):
+            image = [image]
+        input_images, image_sizes = self.construct_input(image)
+        outputs = self._inference(input_images, prompt_type)
+        post_processed_outputs = self.postprocess(outputs, image_sizes)
+        return post_processed_outputs
+
+    def _inference(self, input_images: List[torch.Tensor], prompt_type: str):
+        """Inference for UPN.
+
+        Args:
+            input_images (List[torch.Tensor]): Transformed images.
+
+        Returns:
+            (Dict): Return dict with keys:
+                - query_features: (torch.Tensor): Query features in shape (batch_size, N, 256)
+                - pred_boxes: (torch.Tensor): Normalized prediction boxes in shape (batch_size, N, 4),
+                    in cxcywh format
+        """
+        input_images = nested_tensor_from_tensor_list(input_images)
+        input_images = input_images.to("cuda")
+        with torch.no_grad():
+            outputs = self.model(input_images, prompt_type)
+        return outputs
+
+    def construct_input(self, image: List[Union[str, Image.Image]]):
+        """Construct input for the model.
+
+        Args:
+            image (List[Union[str, Image.Image]]): A list of image paths or
+                PIL.Image.Image objects. If the length of the list is more than 1, the model will
+                perform batch inference.
+
+        Returns:
+            Tuple[torch.Tensor, List[List[int]]]: A tuple containing the
+                input images and the sizes of the input images.
+        """
+        input_images = []
+        image_sizes = []
+        for _, img in enumerate(image):
+            if isinstance(img, str):
+                img = Image.open(img)
+            elif isinstance(img, Image.Image):
+                img = img
+            else:
+                raise ValueError(
+                    "image must be either a string or a PIL.Image.Image object"
+                )
+            W, H = img.size
+            image_sizes.append([H, W])
+            # add image in tensor format
+            input_images.append(self.transform_image(img))
+        return input_images, image_sizes
+
+    def transform_image(self, image_pil: Image.Image) -> torch.Tensor:
+        """Apply a set of transformations to a PIL image.
+
+        Args:
+            image_pil (PIL.Image.Image): The input image.
+
+        Returns:
+            torch.Tensor: The transformed image as a PyTorch tensor.
+        """
+        transform = T.Compose(
+            [
+                T.RandomResize([800], max_size=1333),
+                T.ToTensor(),
+                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+            ]
+        )
+        transformed_image, _ = transform(image_pil, None)  # 3, h, w
+        return transformed_image
+
+    def postprocess(
+        self,
+        outputs: Dict[str, torch.Tensor],
+        image_pil_sizes: List[List[int]] = None,
+    ):
+        boxes = outputs["pred_boxes"].cpu()
+        scores = (
+            outputs["pred_logits"].sigmoid().cpu() if "pred_logits" in outputs else None
+        )
+        normalized_xyxy_boxes = []
+        original_xyxy_boxes = []
+        for batch_idx, (H, W) in enumerate(image_pil_sizes):
+            batch_boxes = boxes[batch_idx]  # (num_queries, 4)
+            # from (cx, cy, w, h) to (x1, y1, x2, y2)
+            batch_boxes[:, 0] = batch_boxes[:, 0] - batch_boxes[:, 2] / 2
+            batch_boxes[:, 1] = batch_boxes[:, 1] - batch_boxes[:, 3] / 2
+            batch_boxes[:, 2] = batch_boxes[:, 0] + batch_boxes[:, 2]
+            batch_boxes[:, 3] = batch_boxes[:, 1] + batch_boxes[:, 3]
+            normalized_xyxy_boxes.append(copy.deepcopy(batch_boxes))
+            # scale boxes
+            original_boxes = (
+                batch_boxes.clone()
+            )  # Copy the normalized boxes to scale to original sizes
+            original_boxes[:, 0] = original_boxes[:, 0] * W
+            original_boxes[:, 1] = original_boxes[:, 1] * H
+            original_boxes[:, 2] = original_boxes[:, 2] * W
+            original_boxes[:, 3] = original_boxes[:, 3] * H
+            original_xyxy_boxes.append(original_boxes)
+
+        original_xyxy_boxes = torch.stack(original_xyxy_boxes)
+        original_xyxy_boxes = original_xyxy_boxes.numpy()
+
+        # sort everything by score from highest to lowest
+        sorted_original_boxes = []
+        sorted_scores = []
+        for i in range(len(normalized_xyxy_boxes)):
+            scores_i = scores[i] if scores is not None else None
+            # sort in descending order
+            sorted_indices = scores_i.squeeze(-1).argsort(descending=True)
+            sorted_original_boxes.append(original_xyxy_boxes[i][sorted_indices])
+            sorted_scores.append(scores_i[sorted_indices])
+
+        original_xyxy_boxes = np.stack(sorted_original_boxes)
+        scores = torch.stack(sorted_scores)
+
+        return dict(
+            original_xyxy_boxes=original_xyxy_boxes,
+            scores=scores,
+        )
+
+    def filter(self, result: Dict, min_score: float, nms_value: float = 0.8):
+        """Filter the UPN detection result. Only keep boxes with score above min_score
+        and apply Non-Maximum Suppression (NMS) to filter overlapping boxes.
+
+        Args:
+            result (Dict): A dictionary containing detection results with 'original_xyxy_boxes' and 'scores'.
+            min_score (float): Minimum score threshold for keeping a box.
+            nms_value (float): NMS IoU threshold for filtering boxes.
+
+        Returns:
+            Dict: Filtered result containing 'original_xyxy_boxes' and 'scores' with the filtered boxes.
+        """
+        filtered_result = {"original_xyxy_boxes": [], "scores": []}
+
+        for boxes, scores in zip(
+            np.array(result["original_xyxy_boxes"]), result["scores"].numpy()
+        ):
+            # Filter out boxes with score below min_score
+            keep = scores >= min_score
+            boxes = boxes[keep[:, 0]]
+            scores = scores[keep[:, 0]][:, 0]
+
+            if len(boxes) == 0:
+                return filtered_result
+
+            # Convert to torch tensors
+            boxes = torch.tensor(boxes, dtype=torch.float32)
+            scores = torch.tensor(scores, dtype=torch.float32)
+
+            # Apply Non-Maximum Suppression (NMS)
+            if nms_value > 0:
+                keep_indices = nms(boxes, scores, nms_value)
+            else:
+                keep_indices = torch.arange(len(boxes))
+
+            # Keep only the boxes that passed NMS
+            filtered_boxes = boxes[keep_indices].numpy().astype(np.int32)
+            filtered_scores = scores[keep_indices].numpy()
+
+            # Sort the boxes by score in descending order
+            sorted_indices = np.argsort(filtered_scores)[::-1]
+            filtered_boxes = filtered_boxes[sorted_indices]
+            filtered_scores = filtered_scores[sorted_indices]
+
+            # Round the scores to two decimal places
+            filtered_scores = [round(score, 2) for score in filtered_scores]
+
+            # Store the filtered boxes and scores in the result dictionary
+            filtered_result["original_xyxy_boxes"].append(filtered_boxes.tolist())
+            filtered_result["scores"].append(filtered_scores)
+
+        return filtered_result
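End to end, the wrapper is used the way `detect_model_upn` in the demo uses it: `inference` then `filter`. A minimal standalone sketch, assuming the checkpoint from the demo's `__main__` block exists locally and a CUDA device is available:

from PIL import Image
from detect_tools.upn import UPNWrapper

wrapper = UPNWrapper("./resources/upn_large.pth")    # moves the model to cuda in __init__
proposals = wrapper.inference(Image.open("demo/sam3_examples/00000-22.jpg"))
filtered = wrapper.filter(proposals, min_score=0.3)  # score threshold + NMS
top_boxes = filtered["original_xyxy_boxes"][0][:100] # xyxy boxes, best first
print(len(top_boxes), "proposals kept")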
detect_tools/upn/models/architecture/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .deformable_transformer import DeformableTransformer
+from .upn_model import UPN
+
+__all__ = ["UPN", "DeformableTransformer"]
detect_tools/upn/models/architecture/deformable_transformer.py
ADDED
@@ -0,0 +1,336 @@
+import math
+from typing import Dict, List
+
+import torch
+import torch.nn as nn
+
+from detect_tools.upn import ARCHITECTURES, build_decoder, build_encoder
+from detect_tools.upn.models.utils import (gen_encoder_output_proposals,
+                                           inverse_sigmoid)
+from detect_tools.upn.ops.modules import MSDeformAttn
+
+
+@ARCHITECTURES.register_module()
+class DeformableTransformer(nn.Module):
+    """Implementation of Deformable DETR.
+
+    Args:
+        encoder_cfg (Dict): Config for the TransformerEncoder.
+        decoder_cfg (Dict): Config for the TransformerDecoder.
+        num_queries (int): Number of queries. This is for the matching part. Default: 900.
+        d_model (int): Dimension of the model. Default: 256.
+        num_feature_levels (int): Number of feature levels. Default: 1.
+        binary_query_selection (bool): Whether to use binary query selection. Default: False.
+            When using binary query selection, a linear layer with out channel = 1 will be used to select
+            topk proposals. Otherwise, we will use ContrastiveAssign to select topk proposals.
+        learnable_tgt_init (bool): Whether to use learnable target init. Default: True. If False,
+            we will use the topk encoder features as the target init.
+        random_refpoints_xy (bool): Whether to use random refpoints xy. This is only used when
+            two_stage_type is not 'no'. Default: False. If True, we will use random refpoints xy.
+        two_stage_type (str): Type of two stage. Default: 'standard'. Options: 'no', 'standard'.
+        two_stage_learn_wh (bool): Whether to learn the width and height of anchor boxes. Default: False.
+        two_stage_keep_all_tokens (bool): If False, the returned hs_enc, ref_enc, init_box_proposal
+            will only be the topk proposals. Otherwise, we will return all the proposals from the
+            encoder. Default: False.
+        two_stage_bbox_embed_share (bool): Whether to share the bbox embedding between the two stages.
+            Default: False.
+        two_stage_class_embed_share (bool): Whether to share the class embedding between the two stages.
+        rm_self_attn_layers (List[int]): The indices of the decoder layers to remove self-attention.
+            Default: None.
+        rm_detach (bool): Whether to detach the decoder output. Default: None.
+        embed_init_tgt (bool): If true, the target embedding is learnable. Otherwise, we will use
+            the topk encoder features as the target init. Default: True.
+    """
+
+    def __init__(
+        self,
+        encoder_cfg: Dict,
+        decoder_cfg: Dict,
+        mask_decoder_cfg: Dict = None,
+        num_queries: int = 900,
+        d_model: int = 256,
+        num_feature_levels: int = 4,
+        binary_query_selection: bool = False,
+        # init query (target)
+        learnable_tgt_init=True,
+        random_refpoints_xy=False,
+        # for two stage
+        two_stage_type: str = "standard",
+        two_stage_learn_wh: bool = False,
+        two_stage_keep_all_tokens: bool = False,
+        two_stage_bbox_embed_share: bool = False,
+        two_stage_class_embed_share: bool = False,
+        # evo of #anchors
+        rm_self_attn_layers: List[int] = None,
+        # for detach
+        rm_detach: bool = None,
+        with_encoder_out: bool = True,
+    ) -> None:
+        super().__init__()
+        self.binary_query_selection = binary_query_selection
+        self.num_queries = num_queries
+        self.num_feature_levels = num_feature_levels
+        self.rm_self_attn_layers = rm_self_attn_layers
+        self.d_model = d_model
+        self.two_stage_bbox_embed_share = two_stage_bbox_embed_share
+        self.two_stage_class_embed_share = two_stage_class_embed_share
+
+        if self.binary_query_selection:
+            self.binary_query_selection_layer = nn.Linear(d_model, 1)
+
+        # build encoder
+        self.encoder = build_encoder(encoder_cfg)
+
+        # build decoder
+        self.decoder = build_decoder(decoder_cfg)
+        self.num_decoder_layers = self.decoder.num_layers
+
+        # build sole mask decoder
+        if mask_decoder_cfg is not None:
+            self.mask_decoder = build_decoder(mask_decoder_cfg)
+        else:
+            self.mask_decoder = None
+        # level embedding
+        if num_feature_levels > 1:
+            self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
+
+        # learnable target embedding
+        self.learnable_tgt_init = learnable_tgt_init
+        assert learnable_tgt_init, "why not learnable_tgt_init"
+
+        self.tgt_embed = nn.Embedding(num_queries, d_model)
+        nn.init.normal_(self.tgt_embed.weight.data)
+
+        # for two stage
+        # TODO: this part is really confusing
+        self.two_stage_type = two_stage_type
+        self.two_stage_learn_wh = two_stage_learn_wh
+        self.two_stage_keep_all_tokens = two_stage_keep_all_tokens
+        assert two_stage_type in [
+            "no",
+            "standard",
+        ], "unknown param {} of two_stage_type".format(two_stage_type)
+        self.with_encoder_out = with_encoder_out
+        if two_stage_type == "standard":
+            # anchor selection at the output of encoder
+            if with_encoder_out:
+                self.enc_output = nn.Linear(d_model, d_model)
+                self.enc_output_norm = nn.LayerNorm(d_model)
+
+            if two_stage_learn_wh:
+                # import ipdb; ipdb.set_trace()
+                self.two_stage_wh_embedding = nn.Embedding(1, 2)
+            else:
+                self.two_stage_wh_embedding = None
+
+        elif two_stage_type == "no":
+            self.init_ref_points(
+                num_queries, random_refpoints_xy
+            )  # init self.refpoint_embed
+
+        self.enc_out_class_embed = None  # this will be initialized outside of the model
+        self.enc_out_bbox_embed = None  # this will be initialized outside of the model
+
+        # remove some self_attn_layers or rm_detach
+        self._reset_parameters()
+
+        self.rm_self_attn_layers = rm_self_attn_layers
+        if rm_self_attn_layers is not None:
+            # assert len(rm_self_attn_layers) == num_decoder_layers
+            print(
+                "Removing the self-attn in {} decoder layers".format(
+                    rm_self_attn_layers
+                )
+            )
+            for lid, dec_layer in enumerate(self.decoder.layers):
+                if lid in rm_self_attn_layers:
+                    dec_layer.rm_self_attn_modules()
+
+        self.rm_detach = rm_detach
+        if self.rm_detach:
+            assert isinstance(rm_detach, list)
+            assert any([i in ["enc_ref", "enc_tgt", "dec"] for i in rm_detach])
+        self.decoder.rm_detach = rm_detach
+
+    def _reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+        for m in self.modules():
+            if isinstance(m, MSDeformAttn):
+                m._reset_parameters()
+        if self.num_feature_levels > 1 and self.level_embed is not None:
+            nn.init.normal_(self.level_embed)
+
+        if self.two_stage_learn_wh:
+            nn.init.constant_(
+                self.two_stage_wh_embedding.weight, math.log(0.05 / (1 - 0.05))
+            )
+
+    def init_ref_points(self, num_queries: int, random_refpoints_xy: bool = False):
+        """Initialize learnable reference points for each query.
+
+        Args:
+            num_queries (int): number of queries
+            random_refpoints_xy (bool, optional): whether to init the refpoints randomly.
+                Defaults to False.
+        """
+        self.refpoint_embed = nn.Embedding(num_queries, 4)
+        if random_refpoints_xy:
+            self.refpoint_embed.weight.data[:, :2].uniform_(0, 1)
+            self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(
+                self.refpoint_embed.weight.data[:, :2]
+            )
+            self.refpoint_embed.weight.data[:, :2].requires_grad = False
+
+    def get_valid_ratio(self, mask):
+        _, H, W = mask.shape
+        valid_H = torch.sum(~mask[:, :, 0], 1)
+        valid_W = torch.sum(~mask[:, 0, :], 1)
+        valid_ratio_h = valid_H.float() / H
+        valid_ratio_w = valid_W.float() / W
+        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
+        return valid_ratio
+
+    def forward(
+        self,
+        src_flatten: torch.Tensor,
+        lvl_pos_embed_flatten: torch.Tensor,
+        level_start_index: List[int],
+        spatial_shapes: List[torch.Tensor],
+        valid_ratios: List[torch.Tensor],
+        mask_flatten: torch.Tensor,
+        prompt_type: str,
+    ) -> List[torch.Tensor]:
+        """Forward function."""
+        memory = self.encoder(
+            src_flatten,
+            pos=lvl_pos_embed_flatten,
+            level_start_index=level_start_index,
+            spatial_shapes=spatial_shapes,
+            valid_ratios=valid_ratios,
+            key_padding_mask=mask_flatten,
+        )
+        batch_size = src_flatten.shape[0]
+        crop_region_features = torch.zeros(batch_size, 1, self.d_model).to(
+            memory.device
+        )
+        if prompt_type == "fine_grained_prompt":
+            crop_region_features = (
+                self.fine_grained_prompt.weight[0]
+                .unsqueeze(0)
+                .unsqueeze(0)
+                .repeat(batch_size, 1, 1)
+            )
+        elif prompt_type == "coarse_grained_prompt":
+            crop_region_features = (
+                self.coarse_grained_prompt.weight[0]
+                .unsqueeze(0)
+                .unsqueeze(0)
+                .repeat(batch_size, 1, 1)
+            )
+        pad_mask = torch.ones(batch_size, 1).to(crop_region_features.device).bool()
+        self_attn_mask = torch.ones(batch_size, 1, 1).to(crop_region_features.device)
+        ref_dict = dict(
+            encoded_ref_feature=crop_region_features,
+            pad_mask=pad_mask,
+            self_attn_mask=self_attn_mask,
+            prompt_type="universal_prompt",
+        )
+
+        (
+            refpoint_embed,
+            tgt,
+            init_box_proposal,
+        ) = self.get_two_stage_proposal(memory, mask_flatten, spatial_shapes, ref_dict)
+
+        hs, references = self.decoder(
+            tgt=tgt.transpose(0, 1),
+            tgt_key_padding_mask=None,
+            memory=memory.transpose(0, 1),
+            memory_key_padding_mask=mask_flatten,
+            pos=lvl_pos_embed_flatten.transpose(0, 1),
+            refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
+            level_start_index=level_start_index,
+            spatial_shapes=spatial_shapes,
+            valid_ratios=valid_ratios,
+            tgt_mask=None,
+            # we ~ the mask. False means use the token; True means pad the token.
+        )
+        hs_enc = ref_enc = None
+        return (
+            hs,
+            references,
+            ref_dict,
+        )
+
+    def get_two_stage_proposal(
+        self,
+        memory: torch.Tensor,
+        mask_flatten: torch.Tensor,
+        spatial_shapes: List[torch.Tensor],
+        ref_dict: Dict,
+    ) -> List[torch.Tensor]:
+        """Two-stage proposal generation for the decoder.
+
+        Args:
+            memory (torch.Tensor): Image encoded feature. [bs, n, 256]
+            mask_flatten (torch.Tensor): Flattened mask. [bs, n]
+            spatial_shapes (List[torch.Tensor]): Spatial shapes of each feature map. [bs, num_levels, 2]
+            refpoint_embed_dn (torch.Tensor): Denoising refpoint embedding. [bs, num_dn_queries, 256]
+            tgt_dn (torch.Tensor): Denoising target embedding. [bs, num_dn_queries, 256]
+            ref_dict (Dict): A dict containing all kinds of reference image related features.
+        """
+        bs = memory.shape[0]
+        input_hw = None
+        output_memory, output_proposals = gen_encoder_output_proposals(
+            memory, mask_flatten, spatial_shapes, input_hw
+        )
+        output_memory = self.enc_output_norm(self.enc_output(output_memory))
+
+        if self.binary_query_selection:  # Unused
+            topk_logits = self.binary_query_selection_layer(output_memory).squeeze(-1)
+        else:
+            if ref_dict is not None:
+                enc_outputs_class_unselected = self.enc_out_class_embed(
+                    output_memory, ref_dict
+                )  # not a linear prediction layer but contrastive similarity, shape [B, len_image, len_text]
+            else:
+                enc_outputs_class_unselected = self.enc_out_class_embed(output_memory)
+            topk_logits = enc_outputs_class_unselected.max(-1)[
+                0
+            ]  # shape [B, len_image]
+        enc_outputs_coord_unselected = (
+            self.enc_out_bbox_embed(output_memory) + output_proposals
+        )  # (bs, \sum{hw}, 4) unsigmoid
+        topk = self.num_queries
+
+        try:
+            topk_proposals = torch.topk(topk_logits, topk, dim=1)[1]  # bs, nq
+        except Exception:
+            raise ValueError(f"topk selection failed for logits of shape {topk_logits.shape}")
+
+        # gather boxes
+        refpoint_embed_undetach = torch.gather(
+            enc_outputs_coord_unselected,
+            1,
+            topk_proposals.unsqueeze(-1).repeat(1, 1, 4),
+        )  # unsigmoid
+        refpoint_embed_ = refpoint_embed_undetach.detach()
+        init_box_proposal = torch.gather(
+            output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
+        ).sigmoid()  # sigmoid
+        # gather tgt
+        tgt_undetach = torch.gather(
+            output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model)
+        )
+        tgt_ = (
+            self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
+        )  # nq, bs, d_model
+        refpoint_embed, tgt = refpoint_embed_, tgt_
+
+        return (
+            refpoint_embed,
+            tgt,
+            init_box_proposal,
+        )
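The heart of `get_two_stage_proposal` is a select-then-gather pattern: score every encoder token, keep the `num_queries` best, and gather their box coordinates and features. A self-contained sketch of just that pattern with toy shapes:

import torch

bs, n_tokens, d_model, topk = 2, 1000, 256, 10
logits = torch.randn(bs, n_tokens)            # per-token objectness (max over classes)
coords = torch.randn(bs, n_tokens, 4)         # unsigmoided proposal boxes
memory = torch.randn(bs, n_tokens, d_model)   # encoder output features

idx = torch.topk(logits, topk, dim=1)[1]      # (bs, topk) indices of the best tokens
top_boxes = torch.gather(coords, 1, idx.unsqueeze(-1).repeat(1, 1, 4))
top_feats = torch.gather(memory, 1, idx.unsqueeze(-1).repeat(1, 1, d_model))
print(top_boxes.shape, top_feats.shape)       # torch.Size([2, 10, 4]) torch.Size([2, 10, 256])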
detect_tools/upn/models/architecture/upn_model.py
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy
from typing import Dict, List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from detect_tools.upn import ARCHITECTURES, build_architecture, build_backbone
from detect_tools.upn.models.module import (MLP, ContrastiveAssign, NestedTensor,
                                            nested_tensor_from_tensor_list)
from detect_tools.upn.models.utils import inverse_sigmoid


class LayerNorm2d(nn.Module):
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


@ARCHITECTURES.register_module()
class UPN(nn.Module):
    """Implementation of UPN"""

    def __init__(
        self,
        vision_backbone_cfg: Dict,
        transformer_cfg: Dict,
        num_queries: int,
        dec_pred_class_embed_share=True,
        dec_pred_bbox_embed_share=True,
        decoder_sa_type="sa",
    ):
        super().__init__()
        # build vision backbone
        self.backbone = build_backbone(vision_backbone_cfg)
        # build transformer
        self.transformer = build_architecture(transformer_cfg)

        self.hidden_dim = self.transformer.d_model

        # for dn training
        self.num_queries = num_queries
        self.num_feature_levels = self.transformer.num_feature_levels

        # prepare projection layer for vision feature
        self.input_proj = self.prepare_vision_feature_projection_layer(
            self.backbone,
            self.transformer.num_feature_levels,
            self.hidden_dim,
            self.transformer.two_stage_type,
        )
        # prepare prediction head
        self.prepare_prediction_head(
            dec_pred_class_embed_share,
            dec_pred_bbox_embed_share,
            self.hidden_dim,
            self.transformer.num_decoder_layers,
        )

        self.decoder_sa_type = decoder_sa_type
        assert decoder_sa_type in ["sa", "ca_label", "ca_content"]
        # self.replace_sa_with_double_ca = replace_sa_with_double_ca

        for layer in self.transformer.decoder.layers:
            layer.label_embedding = None
        self.label_embedding = None

        # build a universal token
        self.transformer.fine_grained_prompt = nn.Embedding(1, self.hidden_dim)
        self.transformer.coarse_grained_prompt = nn.Embedding(1, self.hidden_dim)

        self._reset_parameters()

    def forward(self, samples: NestedTensor, prompt_type: str = None) -> Dict:
        """Forward function."""
        self.device = samples.device

        (
            src_flatten,
            lvl_pos_embed_flatten,
            level_start_index,
            spatial_shapes,
            valid_ratios,
            mask_flatten,
        ) = self.forward_backbone_encoder(samples)

        (
            hs,
            reference,
            ref_dict,
        ) = self.transformer(
            src_flatten,
            lvl_pos_embed_flatten,
            level_start_index,
            spatial_shapes,
            valid_ratios,
            mask_flatten,
            prompt_type,
        )

        # deformable-detr-like anchor update
        outputs_coord_list = []
        outputs_class = []

        for layer_idx, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(
            zip(reference[:-1], self.bbox_embed, hs)
        ):
            layer_delta_unsig = layer_bbox_embed(layer_hs)
            layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
            layer_outputs_unsig = layer_outputs_unsig.sigmoid()
            outputs_coord_list.append(layer_outputs_unsig)

        outputs_coord_list = torch.stack(outputs_coord_list)

        if ref_dict is None:
            # build a mock outputs_class for mask_dn training
            outputs_class = torch.zeros(
                outputs_coord_list.shape[0],
                outputs_coord_list.shape[1],
                outputs_coord_list.shape[2],
                self.hidden_dim,
            )
        else:
            outputs_class = torch.stack(
                [
                    layer_cls_embed(layer_hs, ref_dict)
                    for layer_cls_embed, layer_hs in zip(self.class_embed, hs)
                ]
            )

        out = {
            "pred_logits": outputs_class[-1],
            "pred_boxes": outputs_coord_list[-1],
        }
        out["ref_dict"] = ref_dict
        return out

    def forward_backbone_encoder(self, samples: NestedTensor) -> Tuple:
        # pass through backbone
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        features, poss = self.backbone(samples)
        # project features
        srcs = []
        masks = []
        for l, feat in enumerate(features):
            src, mask = feat.decompose()
            srcs.append(self.input_proj[l](src))  # project the feature map to hidden_dim channels
            masks.append(mask)
            assert mask is not None

        if self.num_feature_levels > len(
            srcs
        ):  # add more feature levels by downsampling the last feature map
            _len_srcs = len(srcs)
            for l in range(_len_srcs, self.num_feature_levels):
                if l == _len_srcs:
                    src = self.input_proj[l](features[-1].tensors)
                else:
                    src = self.input_proj[l](srcs[-1])
                m = samples.mask
                mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(
                    torch.bool
                )[0]
                pos_l = self.backbone.forward_pos_embed_only(
                    NestedTensor(src, mask)
                ).to(src.dtype)
                srcs.append(src)
                masks.append(mask)
                poss.append(pos_l)

        # prepare input for encoder with the following steps:
        # 1. flatten the feature maps and masks
        # 2. Add positional embedding and level embedding
        # 3. Calculate the valid ratio of each feature map based on the mask
        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, poss)):
            bs, c, h, w = src.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)

            src = src.flatten(2).transpose(1, 2)  # bs, hw, c
            mask = mask.flatten(1)  # bs, hw
            pos_embed = pos_embed.flatten(2).transpose(1, 2)  # bs, hw, c
            if self.num_feature_levels > 1 and self.transformer.level_embed is not None:
                lvl_pos_embed = pos_embed + self.transformer.level_embed[lvl].view(
                    1, 1, -1
                )
            else:
                lvl_pos_embed = pos_embed
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            src_flatten.append(src)
            mask_flatten.append(mask)
        src_flatten = torch.cat(src_flatten, 1)
        mask_flatten = torch.cat(mask_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=src_flatten.device
        )
        level_start_index = torch.cat(
            (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
        )
        valid_ratios = torch.stack(
            [self.transformer.get_valid_ratio(m) for m in masks], 1
        )

        return (
            src_flatten,
            lvl_pos_embed_flatten,
            level_start_index,
            spatial_shapes,
            valid_ratios,
            mask_flatten,
        )

    def prepare_vision_feature_projection_layer(
        self,
        backbone: nn.Module,
        num_feature_levels: int,
        hidden_dim: int,
        two_stage_type: str,
    ) -> nn.ModuleList:
        """Prepare projection layer to map backbone feature to hidden dim.

        Args:
            backbone (nn.Module): Backbone.
            num_feature_levels (int): Number of feature levels.
            hidden_dim (int): Hidden dim.
            two_stage_type (str): Type of two stage.

        Returns:
            nn.ModuleList: Projection layer.
        """
        if num_feature_levels > 1:
            num_backbone_outs = len(backbone.num_channels)
            input_proj_list = []
            for _ in range(num_backbone_outs):
                in_channels = backbone.num_channels[_]
                input_proj_list.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
                        nn.GroupNorm(32, hidden_dim),
                    )
                )
            for _ in range(num_feature_levels - num_backbone_outs):
                input_proj_list.append(
                    nn.Sequential(
                        nn.Conv2d(
                            in_channels, hidden_dim, kernel_size=3, stride=2, padding=1
                        ),
                        nn.GroupNorm(32, hidden_dim),
                    )
                )
                in_channels = hidden_dim
            input_proj = nn.ModuleList(input_proj_list)
        else:
            assert (
                two_stage_type == "no"
            ), "two_stage_type should be no if num_feature_levels=1 !!!"
            input_proj = nn.ModuleList(
                [
                    nn.Sequential(
                        nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1),
                        nn.GroupNorm(32, hidden_dim),
                    )
                ]
            )
        return input_proj

    def prepare_prediction_head(
        self,
        dec_pred_class_embed_share: bool,
        dec_pred_bbox_embed_share: bool,
        hidden_dim: int,
        num_decoder_layers: int,
    ) -> None:
        """Prepare prediction head, including class embed and bbox embed.

        Args:
            dec_pred_class_embed_share (bool): Whether to share class embed for all decoder layers.
            dec_pred_bbox_embed_share (bool): Whether to share bbox embed for all decoder layers.
            hidden_dim (int): Hidden dim.
            num_decoder_layers (int): Number of decoder layers.
        """
        _class_embed = ContrastiveAssign()
        _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
        if dec_pred_bbox_embed_share:
            box_embed_layerlist = [_bbox_embed for _ in range(num_decoder_layers)]
        else:
            box_embed_layerlist = [
                copy.deepcopy(_bbox_embed) for i in range(num_decoder_layers)
            ]
        if dec_pred_class_embed_share:
            class_embed_layerlist = [_class_embed for i in range(num_decoder_layers)]
        else:
            class_embed_layerlist = [
                copy.deepcopy(_class_embed) for i in range(num_decoder_layers)
            ]
        bbox_embed = nn.ModuleList(box_embed_layerlist)
        class_embed = nn.ModuleList(class_embed_layerlist)
        self.bbox_embed = bbox_embed
        self.class_embed = class_embed

        # initialize bbox embed and class embed in transformer
        self.transformer.decoder.bbox_embed = bbox_embed
        self.transformer.decoder.class_embed = class_embed

        if self.transformer.two_stage_type != "no":
            if self.transformer.two_stage_bbox_embed_share:
                assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
                self.transformer.enc_out_bbox_embed = _bbox_embed
            else:
                self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)

            if self.transformer.two_stage_class_embed_share:
                assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
                self.transformer.enc_out_class_embed = _class_embed
            else:
                self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed)

        self.refpoint_embed = None

    def _reset_parameters(self):
        # init input_proj
        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)
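A note on the box refinement loop in UPN.forward: each decoder layer predicts an unnormalized delta that is added to the previous reference box in inverse-sigmoid (logit) space, then squashed back to (0, 1). A standalone sketch of that step follows; the shapes are assumptions for illustration, and inverse_sigmoid is written out locally to match the usual DETR-family helper.

    import torch

    def inverse_sigmoid(x, eps=1e-5):
        # numerically stable logit
        x = x.clamp(min=eps, max=1 - eps)
        return torch.log(x / (1 - x))

    layer_ref_sig = torch.rand(2, 10, 4)             # previous reference boxes in (0, 1)
    layer_delta_unsig = torch.randn(2, 10, 4) * 0.1  # predicted refinement in logit space
    new_boxes = (layer_delta_unsig + inverse_sigmoid(layer_ref_sig)).sigmoid()
    assert new_boxes.min() > 0 and new_boxes.max() < 1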
detect_tools/upn/models/backbone/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .swin import SwinTransformer
from .wrapper import SwinWrapper

__all__ = ["SwinWrapper", "SwinTransformer"]
detect_tools/upn/models/backbone/swin.py
ADDED
@@ -0,0 +1,814 @@
from typing import Dict, List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

from detect_tools.upn import BACKBONES
from detect_tools.upn.models.module import NestedTensor


class Mlp(nn.Module):
    """Multilayer perceptron."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = (
        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    )
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(
        B, H // window_size, W // window_size, window_size, window_size, -1
    )
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    """Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = (
            coords_flatten[:, :, None] - coords_flatten[:, None, :]
        )  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(
            1, 2, 0
        ).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Forward function.
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1],
            -1,
        )  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1
        ).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
                1
            ).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwinTransformerBlock(nn.Module):
    """Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(
        self,
        dim,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert (
            0 <= self.shift_size < self.window_size
        ), "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        self.H = None
        self.W = None

    def forward(self, x, mask_matrix):
        """Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
            mask_matrix: Attention mask for cyclic shift.
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(
                x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
            )
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(
            shifted_x, self.window_size
        )  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(
            -1, self.window_size * self.window_size, C
        )  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(
            x_windows, mask=attn_mask
        )  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(
                shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
            )
        else:
            x = shifted_x

        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class PatchMerging(nn.Module):
    """Patch Merging Layer
    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x


class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of feature channels
        depth (int): Depth of this stage.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
        self,
        dim,
        depth,
        num_heads,
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
    ):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=0 if (i % 2 == 0) else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=(
                        drop_path[i] if isinstance(drop_path, list) else drop_path
                    ),
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, H, W):
        """Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """

        # calculate attention mask for SW-MSA
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        h_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        w_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(
            img_mask, self.window_size
        )  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
            attn_mask == 0, float(0.0)
        )

        for blk in self.blocks:
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W


class PatchEmbed(nn.Module):
    """Image to Patch Embedding
    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function."""
        # padding
        _, _, H, W = x.size()
        if W % self.patch_size[1] != 0:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))

        x = self.proj(x)  # B C Wh Ww
        if self.norm is not None:
            Wh, Ww = x.size(2), x.size(3)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)

        return x


@BACKBONES.register_module()
class SwinTransformer(nn.Module):
    """Swin Transformer backbone.
    A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
        https://arxiv.org/pdf/2103.14030
    Args:
        pretrain_img_size (int): Input image size for training the pretrained model,
            used in absolute position embedding. Default 224.
        patch_size (int | tuple(int)): Patch size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (tuple[int]): Depths of each Swin Transformer stage.
        num_heads (tuple[int]): Number of attention heads of each stage.
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        dilation (bool): If True, the output is 16x downsampled; otherwise 32x downsampled.
    """

    def __init__(
        self,
        pretrain_img_size=224,
        patch_size=4,
        in_chans=3,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.2,
        norm_layer=nn.LayerNorm,
        ape=False,
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        frozen_stages=-1,
        dilation=False,
        use_checkpoint=False,
    ):
        super().__init__()

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.dilation = dilation

        if use_checkpoint:
            print("use_checkpoint!!!!!!!!!!!!!!!!!!!!!!!!")

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
        )

        # absolute position embedding
        if self.ape:
            pretrain_img_size = to_2tuple(pretrain_img_size)
            patch_size = to_2tuple(patch_size)
            patches_resolution = [
                pretrain_img_size[0] // patch_size[0],
                pretrain_img_size[1] // patch_size[1],
            ]

            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
            )
            trunc_normal_(self.absolute_pos_embed, std=0.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        # prepare downsample list
        downsamplelist = [PatchMerging for i in range(self.num_layers)]
        downsamplelist[-1] = None
        num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
        if self.dilation:
            downsamplelist[-2] = None
            num_features[-1] = int(embed_dim * 2 ** (self.num_layers - 1)) // 2
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                # dim=int(embed_dim * 2 ** i_layer),
                dim=num_features[i_layer],
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
                norm_layer=norm_layer,
                # downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                downsample=downsamplelist[i_layer],
                use_checkpoint=use_checkpoint,
            )
            self.layers.append(layer)

        # num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.num_features = num_features

        # add a norm layer for each output
        for i_layer in out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f"norm{i_layer}"
            self.add_module(layer_name, layer)

        self._freeze_stages()

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 1 and self.ape:
            self.absolute_pos_embed.requires_grad = False

        if self.frozen_stages >= 2:
            self.pos_drop.eval()
            for i in range(0, self.frozen_stages - 1):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def forward_raw(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Forward function."""
        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # interpolate the position embedding to the corresponding size
            absolute_pos_embed = F.interpolate(
                self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
            )
            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
        else:
            x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        outs = []
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)

            if i in self.out_indices:
                norm_layer = getattr(self, f"norm{i}")
                x_out = norm_layer(x_out)

                out = (
                    x_out.view(-1, H, W, self.num_features[i])
                    .permute(0, 3, 1, 2)
                    .contiguous()
                )
                outs.append(out)

        return tuple(outs)

    def forward(self, tensor_list: NestedTensor) -> Dict:
        """Forward function.

        Args:
            tensor_list (NestedTensor): NestedTensor object containing tensors and masks.

        Returns:
            Dict: Dict containing output tensors. The structure is as follows.
                - 0: NestedTensor from stage 0.
                - 1: NestedTensor from stage 1.
                - 2: NestedTensor from stage 2.
                - 3: NestedTensor from stage 3.
        """
        x = tensor_list.tensors

        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # interpolate the position embedding to the corresponding size
            absolute_pos_embed = F.interpolate(
                self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
            )
            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
        else:
            x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        outs = []
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)

            if i in self.out_indices:
                norm_layer = getattr(self, f"norm{i}")
                x_out = norm_layer(x_out)

                out = (
                    x_out.view(-1, H, W, self.num_features[i])
                    .permute(0, 3, 1, 2)
                    .contiguous()
                )
                outs.append(out)

        # collect outputs into NestedTensors
        outs_dict = {}
        for idx, out_i in enumerate(outs):
            m = tensor_list.mask
            assert m is not None
            mask = F.interpolate(m[None].float(), size=out_i.shape[-2:]).to(torch.bool)[
                0
            ]
            outs_dict[idx] = NestedTensor(out_i, mask)

        return outs_dict

    def train(self, mode=True):
        """Convert the model into training mode while keeping frozen layers frozen."""
        super(SwinTransformer, self).train(mode)
        self._freeze_stages()


def build_swin_transformer(modelname, pretrain_img_size, **kw):
    assert modelname in [
        "swin_T_224_1k",
        "swin_B_224_22k",
        "swin_B_384_22k",
        "swin_L_224_22k",
        "swin_L_384_22k",
    ]

    model_para_dict = {
        "swin_T_224_1k": dict(
            embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7
        ),
        "swin_B_224_22k": dict(
            embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7
        ),
        "swin_B_384_22k": dict(
            embed_dim=128,
            depths=[2, 2, 18, 2],
            num_heads=[4, 8, 16, 32],
            window_size=12,
        ),
        "swin_L_224_22k": dict(
            embed_dim=192,
            depths=[2, 2, 18, 2],
            num_heads=[6, 12, 24, 48],
            window_size=7,
        ),
        "swin_L_384_22k": dict(
            embed_dim=192,
            depths=[2, 2, 18, 2],
            num_heads=[6, 12, 24, 48],
            window_size=12,
        ),
    }
    kw_cfg = model_para_dict[modelname]
    kw_cfg.update(kw)
    model = SwinTransformer(pretrain_img_size=pretrain_img_size, **kw_cfg)
    return model
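A quick self-check of the window_partition/window_reverse pair defined above: for H and W divisible by window_size, the two functions are exact inverses, since they only reshuffle elements without any arithmetic. The snippet below is a sketch meant to run alongside the definitions above; the concrete shapes are illustrative assumptions.

    import torch

    B, H, W, C, ws = 2, 14, 14, 96, 7
    x = torch.randn(B, H, W, C)
    windows = window_partition(x, ws)        # (B * (H//ws) * (W//ws), ws, ws, C)
    x_back = window_reverse(windows, ws, H, W)
    assert torch.equal(x, x_back)            # round trip is exact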
detect_tools/upn/models/backbone/wrapper.py
ADDED
@@ -0,0 +1,297 @@
from typing import Dict, List, Tuple, Union

import torch
import torch.nn as nn

from detect_tools.upn import BACKBONES, build_backbone, build_position_embedding
from detect_tools.upn.models.module import NestedTensor
from detect_tools.upn.models.utils import clean_state_dict


class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rsqrt,
    without which models other than torchvision.models.resnet[18,34,50,101]
    produce NaNs.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        num_batches_tracked_key = prefix + "num_batches_tracked"
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias


class Joiner(nn.Module):
    """A wrapper for the backbone and the position embedding.

    Args:
        backbone (nn.Module): Backbone module.
        position_embedding (nn.Module): Position embedding module.
    """

    def __init__(self, backbone: nn.Module, position_embedding: nn.Module) -> None:
        super().__init__()
        self.backbone = backbone
        self.pos_embed = position_embedding

    def forward(
        self, tensor_list: NestedTensor
    ) -> Tuple[List[NestedTensor], List[torch.Tensor]]:
        """Forward function.

        Args:
            tensor_list (NestedTensor): NestedTensor wrapping the input tensor.

        Returns:
            List[NestedTensor]: A list of feature maps in NestedTensor format.
            List[torch.Tensor]: A list of position encodings.
        """

        xs = self.backbone(tensor_list)
        out: List[NestedTensor] = []
        pos = []
        for layer_idx, x in xs.items():
            out.append(x)
            # position encoding
            pos.append(self.pos_embed(x).to(x.tensors.dtype))

        return out, pos

    def forward_pos_embed_only(self, x: NestedTensor) -> torch.Tensor:
        """Forward function for the position embedding only. This is used to generate
        the position encoding for additional feature layers.

        Args:
            x (NestedTensor): NestedTensor wrapping the input tensor.

        Returns:
            torch.Tensor: Position encoding.
        """
        return self.pos_embed(x)


@BACKBONES.register_module()
class SwinWrapper(nn.Module):
    """A wrapper for the Swin Transformer.

    Args:
        backbone_cfg (Union[Dict, str]): Config dict to build the backbone. If given a str name,
            `get_swin_config` is called to get the config dict.
        dilation (bool): Whether to use dilation in stage 4.
        position_embedding_cfg (Dict): Config dict to build the position embedding.
        lr_backbone (float): Learning rate of the backbone.
        return_interm_indices (List[int]): Which layers to return.
        backbone_freeze_keywords (List[str]): List of keywords used to freeze backbone parameters.
        use_checkpoint (bool): Whether to use checkpointing. Default: False.
        backbone_ckpt_path (str): Checkpoint path. Default: None.
    """

    def __init__(
        self,
        backbone_cfg: Union[Dict, str],
        dilation: bool,
        position_embedding_cfg: Dict,
        lr_backbone: float,
        return_interm_indices: List[int],
        backbone_freeze_keywords: List[str],
        use_checkpoint: bool = False,
        backbone_ckpt_path: str = None,
    ) -> None:
        super(SwinWrapper, self).__init__()
        pos_embedding = build_position_embedding(position_embedding_cfg)
        train_backbone = lr_backbone > 0
        if not train_backbone:
            raise ValueError("Please set lr_backbone > 0")
        assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]

        # build backbone
        if isinstance(backbone_cfg, str):
            assert backbone_cfg in [
                "swin_T_224_1k",
                "swin_B_224_22k",
                "swin_B_384_22k",
                "swin_L_224_22k",
                "swin_L_384_22k",
            ]
            pretrain_img_size = int(backbone_cfg.split("_")[-2])
            backbone_cfg = get_swin_config(
                backbone_cfg,
                pretrain_img_size,
                out_indices=tuple(return_interm_indices),
                dilation=dilation,
                use_checkpoint=use_checkpoint,
            )
        backbone = build_backbone(backbone_cfg)

        # freeze some layers
        if backbone_freeze_keywords is not None:
            for name, parameter in backbone.named_parameters():
                for keyword in backbone_freeze_keywords:
                    if keyword in name:
                        parameter.requires_grad_(False)
                        break

        # load checkpoint
        if backbone_ckpt_path is not None:
            print("Loading backbone checkpoint from {}".format(backbone_ckpt_path))
            checkpoint = torch.load(backbone_ckpt_path, map_location="cpu")["model"]
            from collections import OrderedDict

            def key_select_function(keyname):
                if "head" in keyname:
                    return False
                if dilation and "layers.3" in keyname:
                    return False
                return True

            _tmp_st = OrderedDict(
                {
                    k: v
                    for k, v in clean_state_dict(checkpoint).items()
                    if key_select_function(k)
                }
            )
            _tmp_st_output = backbone.load_state_dict(_tmp_st, strict=False)
            print(str(_tmp_st_output))

        bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :]
        assert len(bb_num_channels) == len(
            return_interm_indices
        ), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}"

        model = Joiner(backbone, pos_embedding)
        model.num_channels = bb_num_channels
        self.num_channels = bb_num_channels
        self.model = model

    def forward(
        self, tensor_list: NestedTensor
    ) -> Tuple[List[NestedTensor], List[torch.Tensor]]:
        """Forward function.

        Args:
            tensor_list (NestedTensor): NestedTensor wrapping the input tensor.

        Returns:
            List[NestedTensor]: A list of feature maps in NestedTensor format.
            List[torch.Tensor]: A list of position encodings.
        """

        return self.model(tensor_list)

    def forward_pos_embed_only(self, tensor_list: NestedTensor) -> torch.Tensor:
        """Forward function to get the position embedding only.

        Args:
            tensor_list (NestedTensor): NestedTensor wrapping the input tensor.

        Returns:
            torch.Tensor: Position embedding.
        """
        return self.model.forward_pos_embed_only(tensor_list)


def get_swin_config(modelname: str, pretrain_img_size: Tuple[int, int], **kw):
    """Get the Swin config dict.

    Args:
        modelname (str): Name of the model.
        pretrain_img_size (Tuple[int, int]): Image size of the pretrained model.
        kw (Dict): Other keyword arguments.

    Returns:
        Dict: Config dict.
        str: Path to the pretrained checkpoint.
    """
    assert modelname in [
        "swin_T_224_1k",
        "swin_B_224_22k",
        "swin_B_384_22k",
        "swin_L_224_22k",
        "swin_L_384_22k",
    ]
    model_para_dict = {
        "swin_T_224_1k": dict(
            type="SwinTransformer",
| 260 |
+
embed_dim=96,
|
| 261 |
+
depths=[2, 2, 6, 2],
|
| 262 |
+
num_heads=[3, 6, 12, 24],
|
| 263 |
+
window_size=7,
|
| 264 |
+
),
|
| 265 |
+
"swin_B_224_22k": dict(
|
| 266 |
+
type="SwinTransformer",
|
| 267 |
+
embed_dim=128,
|
| 268 |
+
depths=[2, 2, 18, 2],
|
| 269 |
+
num_heads=[4, 8, 16, 32],
|
| 270 |
+
window_size=7,
|
| 271 |
+
),
|
| 272 |
+
"swin_B_384_22k": dict(
|
| 273 |
+
type="SwinTransformer",
|
| 274 |
+
embed_dim=128,
|
| 275 |
+
depths=[2, 2, 18, 2],
|
| 276 |
+
num_heads=[4, 8, 16, 32],
|
| 277 |
+
window_size=12,
|
| 278 |
+
),
|
| 279 |
+
"swin_L_224_22k": dict(
|
| 280 |
+
type="SwinTransformer",
|
| 281 |
+
embed_dim=192,
|
| 282 |
+
depths=[2, 2, 18, 2],
|
| 283 |
+
num_heads=[6, 12, 24, 48],
|
| 284 |
+
window_size=7,
|
| 285 |
+
),
|
| 286 |
+
"swin_L_384_22k": dict(
|
| 287 |
+
type="SwinTransformer",
|
| 288 |
+
embed_dim=192,
|
| 289 |
+
depths=[2, 2, 18, 2],
|
| 290 |
+
num_heads=[6, 12, 24, 48],
|
| 291 |
+
window_size=12,
|
| 292 |
+
),
|
| 293 |
+
}
|
| 294 |
+
kw_cgf = model_para_dict[modelname]
|
| 295 |
+
kw_cgf.update(kw)
|
| 296 |
+
kw_cgf.update(dict(pretrain_img_size=pretrain_img_size))
|
| 297 |
+
return kw_cgf
|
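A quick usage sketch for get_swin_config (not part of the commit; the import path is assumed from the repo layout). The function is a plain dict lookup plus keyword merging, so it can be exercised without building the full model:

# Hypothetical usage sketch, assuming the repo is on PYTHONPATH.
from detect_tools.upn.models.backbone.wrapper import get_swin_config

cfg = get_swin_config("swin_T_224_1k", 224, out_indices=(1, 2, 3), dilation=False)
print(cfg["type"], cfg["embed_dim"], cfg["depths"])  # SwinTransformer 96 [2, 2, 6, 2]
print(cfg["pretrain_img_size"], cfg["out_indices"])  # 224 (1, 2, 3)

This mirrors how SwinWrapper itself resolves a string backbone_cfg: it derives pretrain_img_size by splitting the model name, then forwards the extra keywords into the config dict.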
detect_tools/upn/models/decoder/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .upn_decoder import UPNDecoder, DeformableTransformerDecoderLayer

__all__ = ["UPNDecoder", "DeformableTransformerDecoderLayer"]
detect_tools/upn/models/decoder/upn_decoder.py
ADDED
@@ -0,0 +1,378 @@
from typing import Dict

import torch
import torch.nn as nn

from detect_tools.upn import DECODERS, build_decoder
from detect_tools.upn.models.module import MLP
from detect_tools.upn.models.utils import (gen_sineembed_for_position,
                                           get_activation_fn, get_clones,
                                           inverse_sigmoid)
from detect_tools.upn.ops.modules import MSDeformAttn


@DECODERS.register_module()
class DeformableTransformerDecoderLayer(nn.Module):
    """Deformable Transformer Decoder Layer. This is a modified version of the layer
    used in Grounding DINO, with the text cross-attention removed. The execution
    order is: self_attn -> cross_attn to image -> ffn.

    Args:
        d_model (int): The dimension of keys/values/queries in :class:`MultiheadAttention`.
        d_ffn (int): The dimension of the feedforward network model.
        dropout (float): Probability of an element to be zeroed.
        activation (str): Activation function in the feedforward network.
            'relu' and 'gelu' are supported.
        n_levels (int): The number of levels in Multi-Scale Deformable Attention.
        n_heads (int): Parallel attention heads.
        n_points (int): Number of sampling points in Multi-Scale Deformable Attention.
        ffn_extra_layernorm (bool): If True, add an extra layernorm after the ffn.
    """

    def __init__(
        self,
        d_model: int = 256,
        d_ffn: int = 1024,
        dropout: float = 0.1,
        activation: str = "relu",
        n_levels: int = 4,
        n_heads: int = 8,
        n_points: int = 4,
        ffn_extra_layernorm: bool = False,
    ) -> None:
        super().__init__()

        # cross attention for visual features
        self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm1 = nn.LayerNorm(d_model)

        # self attention for query
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm2 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = get_activation_fn(activation, d_model=d_ffn, batch_dim=1)
        self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm3 = nn.LayerNorm(d_model)
        if ffn_extra_layernorm:
            raise NotImplementedError("ffn_extra_layernorm not implemented")
            self.norm_ext = nn.LayerNorm(d_ffn)
        else:
            self.norm_ext = None

        self.key_aware_proj = None

    def rm_self_attn_modules(self):
        self.self_attn = None
        self.dropout2 = None
        self.norm2 = None

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward(
        self,
        tgt: torch.Tensor,
        tgt_query_pos: torch.Tensor = None,
        tgt_reference_points: torch.Tensor = None,
        memory: torch.Tensor = None,
        memory_key_padding_mask: torch.Tensor = None,
        memory_level_start_index: torch.Tensor = None,
        memory_spatial_shapes: torch.Tensor = None,
        self_attn_mask: torch.Tensor = None,
        cross_attn_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Forward function.

        Args:
            tgt (torch.Tensor): Input target in shape (B, T, C).
            tgt_query_pos (torch.Tensor): Positional encoding of the query.
            tgt_reference_points (torch.Tensor): Reference points for the query in shape (B, T, 4).
            memory (torch.Tensor): Input image feature in shape (B, HW, C).
            memory_key_padding_mask (torch.Tensor): Mask for the image feature in shape (B, HW).
            memory_level_start_index (torch.Tensor): Starting index of each level in memory.
            memory_spatial_shapes (torch.Tensor): Spatial shape of each level in memory.
            self_attn_mask (torch.Tensor): Mask used for self-attention.
            cross_attn_mask (torch.Tensor): Mask used for cross-attention.

        Returns:
            torch.Tensor: Output tensor in shape (B, T, C).
        """
        assert cross_attn_mask is None

        # self attention
        if self.self_attn is not None:
            q = k = self.with_pos_embed(tgt, tgt_query_pos)
            tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
            tgt = tgt + self.dropout2(tgt2)
            tgt = self.norm2(tgt)

        # attend to image features
        tgt2 = self.cross_attn(
            self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
            tgt_reference_points.transpose(0, 1).contiguous(),
            memory.transpose(0, 1),
            memory_spatial_shapes,
            memory_level_start_index,
            memory_key_padding_mask,
        ).transpose(0, 1)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # ffn
        tgt = self.forward_ffn(tgt)

        return tgt


@DECODERS.register_module()
class UPNDecoder(nn.Module):
    """Decoder used in UPN. Each layer is a DeformableTransformerDecoderLayer. The query
    attends to the image feature. The execution order is:
    self_attn -> cross_attn to image -> ffn

    Args:
        decoder_layer_cfg (Dict): Config for the DeformableTransformerDecoderLayer.
        num_layers (int): Number of layers.
        norm (str, optional): Normalization layer type. Defaults to "layernorm".
        return_intermediate (bool, optional): Whether to return intermediate results.
            Defaults to True.
        d_model (int, optional): Dimension of the model. Defaults to 256.
        query_dim (int, optional): Dimension of the query. Defaults to 4.
        modulate_hw_attn (bool, optional): Whether to modulate the attention weights
            by the height and width of the image feature. Defaults to False.
        num_feature_levels (int, optional): Number of feature levels. Defaults to 1.
        deformable_decoder (bool, optional): Whether to use a deformable decoder.
            Defaults to True.
        decoder_query_perturber (optional): Module used to perturb the decoder queries.
            Defaults to None.
        dec_layer_number (List[int], optional): Optional per-layer setting; if given, must
            have length num_layers. Defaults to None.
        rm_dec_query_scale (bool, optional): Whether to remove the decoder query scale.
            Defaults to True.
        dec_layer_share (bool, optional): Whether to share the same layer instance across
            decoder layers. Defaults to False.
        dec_layer_dropout_prob (List[float], optional): Per-layer dropout probabilities in
            [0, 1]; if given, must have length num_layers. Defaults to None.
        use_detached_boxes_dec_out (bool, optional): Whether to store detached reference
            boxes in the decoder outputs. Defaults to False.
    """

    def __init__(
        self,
        decoder_layer_cfg: Dict,
        num_layers: int,
        norm: str = "layernorm",
        return_intermediate: bool = True,
        d_model: int = 256,
        query_dim: int = 4,
        modulate_hw_attn: bool = False,
        num_feature_levels: int = 1,
        deformable_decoder: bool = True,
        decoder_query_perturber=None,
        dec_layer_number=None,
        rm_dec_query_scale: bool = True,
        dec_layer_share: bool = False,
        dec_layer_dropout_prob=None,
        use_detached_boxes_dec_out: bool = False,
    ):
        super().__init__()

        decoder_layer = build_decoder(decoder_layer_cfg)
        if num_layers > 0:
            self.layers = get_clones(
                decoder_layer, num_layers, layer_share=dec_layer_share
            )
        else:
            self.layers = []
        self.num_layers = num_layers
        if norm == "layernorm":
            self.norm = nn.LayerNorm(d_model)
        self.return_intermediate = return_intermediate
        self.query_dim = query_dim
        assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim)
        self.num_feature_levels = num_feature_levels
        self.use_detached_boxes_dec_out = use_detached_boxes_dec_out

        self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
        self.ref_point_head_point = MLP(
            d_model, d_model, d_model, 2
        )  # for point reference only
        if not deformable_decoder:
            self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2)
        else:
            self.query_pos_sine_scale = None

        if rm_dec_query_scale:
            self.query_scale = None
        else:
            raise NotImplementedError
            self.query_scale = MLP(d_model, d_model, d_model, 2)
        self.bbox_embed = None
        self.class_embed = None

        self.d_model = d_model
        self.modulate_hw_attn = modulate_hw_attn
        self.deformable_decoder = deformable_decoder

        if not deformable_decoder and modulate_hw_attn:
            self.ref_anchor_head = MLP(d_model, d_model, 2, 2)
        else:
            self.ref_anchor_head = None

        self.decoder_query_perturber = decoder_query_perturber
        self.box_pred_damping = None

        self.dec_layer_number = dec_layer_number
        if dec_layer_number is not None:
            assert isinstance(dec_layer_number, list)
            assert len(dec_layer_number) == num_layers

        self.dec_layer_dropout_prob = dec_layer_dropout_prob
        if dec_layer_dropout_prob is not None:
            assert isinstance(dec_layer_dropout_prob, list)
            assert len(dec_layer_dropout_prob) == num_layers
            for i in dec_layer_dropout_prob:
                assert 0.0 <= i <= 1.0

        self.rm_detach = None

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: torch.Tensor = None,
        memory_mask: torch.Tensor = None,
        tgt_key_padding_mask: torch.Tensor = None,
        memory_key_padding_mask: torch.Tensor = None,
        pos: torch.Tensor = None,
        refpoints_unsigmoid: torch.Tensor = None,
        level_start_index: torch.Tensor = None,
        spatial_shapes: torch.Tensor = None,
        valid_ratios: torch.Tensor = None,
        memory_ref_image: torch.Tensor = None,
        refImg_padding_mask: torch.Tensor = None,
        memory_visual_prompt: torch.Tensor = None,
    ):
        """Forward function.

        Args:
            tgt (torch.Tensor): Target feature, [bs, num_queries, d_model].
            memory (torch.Tensor): Image feature, [bs, hw, d_model].
            tgt_mask (torch.Tensor, optional): Target mask for attention. Defaults to None.
            memory_mask (torch.Tensor, optional): Image mask for attention. Defaults to None.
            tgt_key_padding_mask (torch.Tensor, optional): Target mask for padding. Defaults to None.
            memory_key_padding_mask (torch.Tensor, optional): Image mask for padding. Defaults to None.
            pos (torch.Tensor, optional): Query position embedding.
            refpoints_unsigmoid (torch.Tensor, optional): Reference points. Defaults to None.
            level_start_index (torch.Tensor, optional): Start index of each level. Defaults to None.
            spatial_shapes (torch.Tensor, optional): Spatial shape of each level. Defaults to None.
            valid_ratios (torch.Tensor, optional): Valid ratio of each level. Defaults to None.
            memory_ref_image (torch.Tensor, optional): Reference image feature,
                [bs, num_ref, d_model]. Defaults to None.
            refImg_padding_mask (torch.Tensor, optional): Padding mask for the reference image
                feature. Defaults to None.
            memory_visual_prompt (torch.Tensor, optional): Visual prompt feature. Defaults to None.
        """
        output = tgt

        intermediate = []
        reference_points = refpoints_unsigmoid.sigmoid()
        ref_points = [reference_points]

        for layer_id, layer in enumerate(self.layers):

            if reference_points.shape[-1] == 4:
                reference_points_input = (
                    reference_points[:, :, None]
                    * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
                )  # nq, bs, nlevel, 4
            else:
                assert reference_points.shape[-1] == 2
                reference_points_input = (
                    reference_points[:, :, None] * valid_ratios[None, :]
                )
            query_sine_embed = gen_sineembed_for_position(
                reference_points_input[:, :, 0, :]
            )  # nq, bs, 256*2

            # conditional query
            if query_sine_embed.shape[-1] == 512:
                raw_query_pos = (
                    self.ref_point_head(query_sine_embed)
                    + self.ref_point_head_point(
                        torch.zeros_like(query_sine_embed)[:, :, :256]
                    )
                    * 0.0
                )
            else:
                raw_query_pos = (
                    self.ref_point_head_point(query_sine_embed)
                    + self.ref_point_head(
                        torch.zeros(
                            query_sine_embed.shape[0],
                            query_sine_embed.shape[1],
                            512,
                            device=query_sine_embed.device,
                        )
                    )
                    * 0.0
                )
            pos_scale = self.query_scale(output) if self.query_scale is not None else 1
            query_pos = pos_scale * raw_query_pos

            # main process
            output = layer(
                tgt=output,
                tgt_query_pos=query_pos,
                tgt_reference_points=reference_points_input,
                memory=memory,
                memory_key_padding_mask=memory_key_padding_mask,
                memory_level_start_index=level_start_index,
                memory_spatial_shapes=spatial_shapes,
                self_attn_mask=tgt_mask,
                cross_attn_mask=memory_mask,
            )
            if output.isnan().any() | output.isinf().any():
                print(f"output layer_id {layer_id} is nan")
                try:
                    num_nan = output.isnan().sum().item()
                    num_inf = output.isinf().sum().item()
                    print(f"num_nan {num_nan}, num_inf {num_inf}")
                except Exception as e:
                    print(e)

            # iterative update
            if self.bbox_embed is not None:

                reference_before_sigmoid = inverse_sigmoid(reference_points)
                delta_unsig = self.bbox_embed[layer_id](output)
                outputs_unsig = delta_unsig + reference_before_sigmoid
                new_reference_points = outputs_unsig.sigmoid()

                if self.rm_detach and "dec" in self.rm_detach:
                    reference_points = new_reference_points
                else:
                    reference_points = new_reference_points.detach()

                if self.use_detached_boxes_dec_out:
                    ref_points.append(reference_points)
                else:
                    ref_points.append(new_reference_points)

            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.return_intermediate:
            return [
                [itm_out.transpose(0, 1) for itm_out in intermediate],
                [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points],
            ]
        else:
            return self.norm(output).transpose(0, 1)
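The iterative update step above is the standard DETR-style box refinement: each layer predicts an unsigmoided delta that is added to the logit of the current reference boxes, keeping the refined boxes in [0, 1]. A standalone sketch (not part of the commit; inverse_sigmoid is redefined locally to mirror the imported helper):

import torch

def inverse_sigmoid(x, eps=1e-5):
    # Numerically stable logit, mirroring the helper imported from detr_utils.
    x = x.clamp(min=0, max=1)
    return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))

reference_points = torch.rand(2, 900, 4)          # (bs, num_queries, cxcywh) in [0, 1]
delta_unsig = 0.1 * torch.randn(2, 900, 4)        # stand-in for self.bbox_embed[layer_id](output)
new_reference_points = (delta_unsig + inverse_sigmoid(reference_points)).sigmoid()
reference_points = new_reference_points.detach()  # detached unless rm_detach contains "dec"
print(new_reference_points.min().item() >= 0, new_reference_points.max().item() <= 1)  # True True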
detect_tools/upn/models/encoder/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .upn_encoder import DeformableTransformerEncoderLayer, UPNEncoder

__all__ = ["UPNEncoder", "DeformableTransformerEncoderLayer"]
detect_tools/upn/models/encoder/upn_encoder.py
ADDED
@@ -0,0 +1,288 @@
from typing import Dict

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

from detect_tools.upn import ENCODERS, build_encoder
from detect_tools.upn.models.utils import get_activation_fn, get_clones
from detect_tools.upn.ops.modules import MSDeformAttn


@ENCODERS.register_module()
class DeformableTransformerEncoderLayer(nn.Module):
    """Deformable Transformer Encoder Layer.

    Args:
        d_model (int): The dimension of keys/values/queries in
            :class:`MultiheadAttention`.
        d_ffn (int): The dimension of the feedforward network model.
        dropout (float): Probability of an element to be zeroed.
        activation (str): Activation function in the feedforward network.
            'relu' and 'gelu' are supported.
        n_levels (int): The number of levels in Multi-Scale Deformable Attention.
        n_heads (int): Parallel attention heads.
        n_points (int): Number of sampling points in Multi-Scale Deformable Attention.
        add_channel_attention (bool): If True, add channel attention.
    """

    def __init__(
        self,
        d_model: int = 256,
        d_ffn: int = 1024,
        dropout: float = 0.1,
        activation: str = "relu",
        n_levels: int = 4,
        n_heads: int = 8,
        n_points: int = 4,
        add_channel_attention: bool = False,
    ) -> None:
        super().__init__()

        # self attention
        self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = get_activation_fn(activation, d_model=d_ffn)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # channel attention
        self.add_channel_attention = add_channel_attention
        if add_channel_attention:
            self.activ_channel = get_activation_fn("dyrelu", d_model=d_model)
            self.norm_channel = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, src: torch.Tensor) -> torch.Tensor:
        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
        src = src + self.dropout3(src2)
        src = self.norm2(src)
        return src

    def forward(
        self,
        src: torch.Tensor,
        pos: torch.Tensor,
        reference_points: torch.Tensor,
        spatial_shapes: torch.Tensor,
        level_start_index: torch.Tensor,
        key_padding_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Forward function for `DeformableTransformerEncoderLayer`.

        Args:
            src (torch.Tensor): The input sequence of shape (S, N, E).
            pos (torch.Tensor): The position embedding of shape (S, N, E).
            reference_points (torch.Tensor): The reference points of shape (N, L, 2).
            spatial_shapes (torch.Tensor): The spatial shapes of feature levels.
            level_start_index (torch.Tensor): The start index of each level.
            key_padding_mask (torch.Tensor): The mask for keys with shape (N, S).
        """
        # self attention
        src2 = self.self_attn(
            self.with_pos_embed(src, pos),
            reference_points,
            src,
            spatial_shapes,
            level_start_index,
            key_padding_mask,
        )
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # ffn
        src = self.forward_ffn(src)

        # channel attention
        if self.add_channel_attention:
            src = self.norm_channel(src + self.activ_channel(src))

        return src


@ENCODERS.register_module()
class UPNEncoder(nn.Module):
    """Implementation of the UPN Encoder.

    Args:
        num_layers (int): The number of layers in the TransformerEncoder.
        d_model (int, optional): The dimension of the input feature. Defaults to 256.
        encoder_layer_cfg (Dict): Config for the DeformableTransformerEncoderLayer.
        use_checkpoint (bool, optional): Whether to use checkpointing in the fusion layer
            for memory saving. Defaults to True.
        use_transformer_ckpt (bool, optional): Whether to use checkpointing for the
            deformable encoder layers. Defaults to True.
        enc_layer_share (bool, optional): Whether to share the same layer instance across
            encoder layers. Defaults to False.
        multi_level_encoder_fusion (str, optional): Fusion strategy across encoder layers,
            one of "dense_net_fusion" and "stable_dense_fusion". Defaults to None.
    """

    def __init__(
        self,
        num_layers: int,
        d_model: int = 256,
        encoder_layer_cfg: Dict = None,
        use_checkpoint: bool = True,
        use_transformer_ckpt: bool = True,
        enc_layer_share: bool = False,
        multi_level_encoder_fusion: str = None,
    ):
        super().__init__()
        # prepare layers
        self.layers = []
        self.refImg_layers = []
        self.fusion_layers = []
        encoder_layer = build_encoder(encoder_layer_cfg)

        self.multi_level_encoder_fusion = multi_level_encoder_fusion
        self._initialize_memory_fusion_layers(
            multi_level_encoder_fusion, num_layers, d_model
        )

        if num_layers > 0:
            self.layers = get_clones(
                encoder_layer, num_layers, layer_share=enc_layer_share
            )
        else:
            self.layers = []
            del encoder_layer

        self.query_scale = None
        self.num_layers = num_layers
        self.d_model = d_model

        self.use_checkpoint = use_checkpoint
        self.use_transformer_ckpt = use_transformer_ckpt

    def _initialize_memory_fusion_layers(self, fusion_type, num_layers, d_model):
        if fusion_type is None:
            self.memory_fusion_layer = None
            return

        assert fusion_type in ["dense_net_fusion", "stable_dense_fusion"]
        if fusion_type == "stable_dense_fusion":
            self.memory_fusion_layer = nn.Sequential(
                nn.Linear(d_model * (num_layers + 1), d_model),
                nn.LayerNorm(d_model),
            )
            nn.init.constant_(self.memory_fusion_layer[0].bias, 0)
        elif fusion_type == "dense_net_fusion":
            self.memory_fusion_layer = nn.ModuleList()
            for i in range(num_layers):
                self.memory_fusion_layer.append(
                    nn.Sequential(
                        nn.Linear(
                            d_model * (i + 2), d_model
                        ),  # from second encoder layer, 512 -> 256 / 3rd: 768 -> 256
                        nn.LayerNorm(d_model),
                    )
                )
            for layer in self.memory_fusion_layer:
                nn.init.constant_(layer[0].bias, 0)
        else:
            raise NotImplementedError

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        reference_points_list = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):

            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
            )
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

    def forward(
        self,
        src: torch.Tensor,
        pos: torch.Tensor,
        spatial_shapes: torch.Tensor,
        level_start_index: torch.Tensor,
        valid_ratios: torch.Tensor,
        key_padding_mask: torch.Tensor = None,
    ):
        """Forward function.

        Args:
            src (torch.Tensor): Flattened image features in shape [bs, sum(hi*wi), 256].
            pos (torch.Tensor): Position embedding for the image feature in shape
                [bs, sum(hi*wi), 256].
            spatial_shapes (torch.Tensor): Spatial shape of each level in shape [num_level, 2].
            level_start_index (torch.Tensor): Start index of each level in shape [num_level].
            valid_ratios (torch.Tensor): Valid ratio of each level in shape [bs, num_level, 2].
            key_padding_mask (torch.Tensor): Padding mask for the image feature in shape
                [bs, sum(hi*wi)].

        Returns:
            torch.Tensor: Encoded image feature in shape [bs, sum(hi*wi), 256].
        """

        output = src
        # preparation and reshape
        if self.num_layers > 0:
            reference_points = self.get_reference_points(
                spatial_shapes, valid_ratios, device=src.device
            )

        # multi-level dense fusion
        output_list = [output]
        # main process
        for layer_id, layer in enumerate(self.layers):
            if self.use_transformer_ckpt:
                output = checkpoint.checkpoint(
                    layer,
                    output,
                    pos,
                    reference_points,
                    spatial_shapes,
                    level_start_index,
                    key_padding_mask,
                )
            else:
                output = layer(
                    src=output,
                    pos=pos,
                    reference_points=reference_points,
                    spatial_shapes=spatial_shapes,
                    level_start_index=level_start_index,
                    key_padding_mask=key_padding_mask,
                )

            output_list.append(output)
            if (
                self.multi_level_encoder_fusion is not None
                and self.multi_level_encoder_fusion == "dense_net_fusion"
            ):
                output = self.memory_fusion_layer[layer_id](
                    torch.cat(output_list, dim=-1)
                )

        if (
            self.multi_level_encoder_fusion is not None
            and self.multi_level_encoder_fusion == "stable_dense_fusion"
        ):
            output = self.memory_fusion_layer(torch.cat(output_list, dim=-1))

        return output
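get_reference_points places one normalized pixel-center coordinate per feature-map location, rescaled by the valid (non-padded) ratio of each level. A standalone sketch (not part of the commit) for a single 2x3 level with valid_ratios == 1:

import torch

H_, W_ = 2, 3
ref_y, ref_x = torch.meshgrid(
    torch.linspace(0.5, H_ - 0.5, H_),
    torch.linspace(0.5, W_ - 0.5, W_),
    indexing="ij",  # torch>=1.10; the original code relies on the same default
)
print(ref_x / W_)  # pixel-center x coords per row: 1/6, 1/2, 5/6
print(ref_y / H_)  # pixel-center y coords: 1/4 (top row), 3/4 (bottom row)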
detect_tools/upn/models/module/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .contrastive import ContrastiveAssign
from .mlp import MLP
from .nested_tensor import NestedTensor, nested_tensor_from_tensor_list

__all__ = ["MLP", "NestedTensor", "nested_tensor_from_tensor_list", "ContrastiveAssign"]
detect_tools/upn/models/module/contrastive.py
ADDED
@@ -0,0 +1,29 @@
from typing import Dict

import torch
import torch.nn as nn


class ContrastiveAssign(nn.Module):

    def __init__(
        self,
        cal_bias: nn.Module = None,
    ) -> None:
        """Language-Image Contrastive Assignment used to calculate the similarity between
        the text and the image.

        Args:
            cal_bias (nn.Module, optional): The bias used to calculate the similarity.
                Defaults to None.
        """
        super().__init__()
        self.cal_bias = cal_bias

    def forward(self, x: torch.Tensor, ref_dict: Dict):

        y = ref_dict["encoded_ref_feature"]
        res = x @ y.transpose(-1, -2)
        return res
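ContrastiveAssign reduces to a batched dot product between query features and encoded reference features. A quick usage sketch (not part of the commit; assumes the repo is importable):

import torch
from detect_tools.upn.models.module import ContrastiveAssign

assign = ContrastiveAssign()
x = torch.randn(2, 900, 256)                            # (bs, num_queries, d_model)
ref = {"encoded_ref_feature": torch.randn(2, 81, 256)}  # (bs, num_refs, d_model)
logits = assign(x, ref)
print(logits.shape)  # torch.Size([2, 900, 81]): one similarity score per query/ref pair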
detect_tools/upn/models/module/mlp.py
ADDED
@@ -0,0 +1,18 @@
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
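Usage sketch (not part of the commit): this is how UPNDecoder's ref_point_head uses the MLP to map the 512-d sine query embedding down to d_model=256.

import torch
from detect_tools.upn.models.module import MLP

mlp = MLP(input_dim=512, hidden_dim=256, output_dim=256, num_layers=2)
out = mlp(torch.randn(900, 2, 512))  # (nq, bs, 512) -> (nq, bs, 256)
print(out.shape)  # torch.Size([900, 2, 256])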
detect_tools/upn/models/module/nested_tensor.py
ADDED
@@ -0,0 +1,199 @@
from typing import List, Union

import torch
import torchvision


class NestedTensor(object):
    """Define a NestedTensor class.

    Args:
        tensors (torch.Tensor): Tensor with shape [batch, C, H, W] or [C, H, W].
        mask (Union[torch.Tensor, str]): Mask with shape [batch, H, W] or [H, W]. If mask
            is 'auto', it will be generated automatically by summing the tensor along
            the channel dimension. The mask is used to indicate the padding area.
    """

    def __init__(
        self, tensors: torch.Tensor, mask: Union[torch.Tensor, str] = "auto"
    ) -> None:
        self.tensors = tensors
        self.mask = mask
        if mask == "auto":
            self.mask = torch.zeros_like(tensors).to(tensors.device)
            if self.mask.dim() == 3:
                self.mask = self.mask.sum(0).to(bool)
            elif self.mask.dim() == 4:
                self.mask = self.mask.sum(1).to(bool)
            else:
                raise ValueError(
                    "tensors dim must be 3 or 4 but {}({})".format(
                        self.tensors.dim(), self.tensors.shape
                    )
                )

    def imgsize(self) -> List[torch.Tensor]:
        """Get the image size of the tensor.

        Returns:
            List[torch.Tensor]: list of tensors with shape [2] which is [H, W]
        """
        res = []
        for i in range(self.tensors.shape[0]):
            mask = self.mask[i]
            maxH = (~mask).sum(0).max()
            maxW = (~mask).sum(1).max()
            res.append(torch.Tensor([maxH, maxW]))
        return res

    def to(self, device: torch.device):
        """Move tensors and mask to the given device.

        Args:
            device (torch.device): device to move to

        Returns:
            NestedTensor: moved NestedTensor
        """
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        if mask is not None:
            cast_mask = mask.to(device)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)

    def to_img_list_single(
        self, tensor: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """Remove the padding for one image.

        Args:
            tensor (torch.Tensor): tensor with shape [C, H, W]
            mask (torch.Tensor): mask with shape [H, W]

        Returns:
            torch.Tensor: tensor with shape [C, maxH, maxW]
        """
        assert tensor.dim() == 3, "dim of tensor should be 3 but {}".format(
            tensor.dim()
        )
        maxH = (~mask).sum(0).max()
        maxW = (~mask).sum(1).max()
        img = tensor[:, :maxH, :maxW]
        return img

    def to_img_list(self) -> List[torch.Tensor]:
        """Remove the padding and convert to an image list.

        Returns:
            List[torch.Tensor]: list of tensors with shape [C, maxH, maxW]
        """
        if self.tensors.dim() == 3:
            return self.to_img_list_single(self.tensors, self.mask)
        else:
            res = []
            for i in range(self.tensors.shape[0]):
                tensor_i = self.tensors[i]
                mask_i = self.mask[i]
                res.append(self.to_img_list_single(tensor_i, mask_i))
            return res

    @property
    def device(self):
        return self.tensors.device

    def decompose(self):
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)

    @property
    def shape(self):
        return {"tensors.shape": self.tensors.shape, "mask.shape": self.mask.shape}


def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes


def nested_tensor_from_tensor_list(
    tensor_list: List[torch.Tensor], fixed_img_size=None
):
    if fixed_img_size is not None:
        if isinstance(fixed_img_size, (list, tuple)):
            assert (
                len(fixed_img_size) == 2
            ), "image size should be a tuple or list with two elements"
        elif isinstance(fixed_img_size, int):
            fixed_img_size = [fixed_img_size, fixed_img_size]

    if tensor_list[0].ndim == 3:
        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX
            # call _onnx_nested_tensor_from_tensor_list() instead
            return _onnx_nested_tensor_from_tensor_list(tensor_list)

        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])

        if fixed_img_size is not None:
            c, orig_h, orig_w = max_size
            assert (
                orig_h <= fixed_img_size[0] and orig_w <= fixed_img_size[1]
            ), f"{orig_h} {orig_w} the fixed output image size should be larger than the original image"
            max_size = [c, fixed_img_size[0], fixed_img_size[1]]

        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], : img.shape[2]] = False
    else:
        raise ValueError("not supported")
    return NestedTensor(tensor, mask)


@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(
    tensor_list: List[torch.Tensor],
) -> NestedTensor:
    max_size = []
    for i in range(tensor_list[0].dim()):
        max_size_i = torch.max(
            torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
        ).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        padded_img = torch.nn.functional.pad(
            img, (0, padding[2], 0, padding[1], 0, padding[0])
        )
        padded_imgs.append(padded_img)

        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(
            m, (0, padding[2], 0, padding[1]), "constant", 1
        )
        padded_masks.append(padded_mask.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)

    return NestedTensor(tensor, mask=mask)
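Usage sketch (not part of the commit): batch two different-sized images into one zero-padded tensor plus a boolean padding mask, where True marks padded pixels.

import torch
from detect_tools.upn.models.module import nested_tensor_from_tensor_list

imgs = [torch.randn(3, 480, 640), torch.randn(3, 600, 400)]
nt = nested_tensor_from_tensor_list(imgs)
print(nt.tensors.shape)  # torch.Size([2, 3, 600, 640]): padded to the per-axis max
print(nt.mask.shape)     # torch.Size([2, 600, 640])
print(nt.imgsize())      # [tensor([480., 640.]), tensor([600., 400.])]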
detect_tools/upn/models/utils/__init__.py
ADDED
@@ -0,0 +1,23 @@
from .detr_utils import (
    PositionEmbeddingLearned,
    PositionEmbeddingSine,
    PositionEmbeddingSineHW,
    clean_state_dict,
    gen_encoder_output_proposals,
    gen_sineembed_for_position,
    get_activation_fn,
    get_clones,
    inverse_sigmoid,
)

__all__ = [
    "inverse_sigmoid",
    "gen_encoder_output_proposals",
    "get_clones",
    "gen_sineembed_for_position",
    "get_activation_fn",
    "clean_state_dict",
    "PositionEmbeddingSine",
    "PositionEmbeddingSineHW",
    "PositionEmbeddingLearned",
]
detect_tools/upn/models/utils/detr_utils.py
ADDED
|
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import math
|
| 3 |
+
from collections import OrderedDict
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch import nn
|
| 9 |
+
|
from detect_tools.upn import POS_EMBEDDINGS
from detect_tools.upn.models.module import NestedTensor


@POS_EMBEDDINGS.register_module()
class PositionEmbeddingSine(nn.Module):
    """This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.

    Args:
        num_pos_feats (int): The channel dimension of the positional embeddings.
        temperature (float): The temperature used in the positional embeddings.
        normalize (bool): Whether to normalize the positional embeddings.
        scale (float): The scale factor of the positional embeddings.
    """

    def __init__(
        self,
        num_pos_feats: int = 64,
        temperature: int = 10000,
        normalize: bool = False,
        scale: float = None,
    ) -> None:
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: NestedTensor) -> torch.Tensor:
        """Forward function.

        Args:
            tensor_list (NestedTensor): NestedTensor wrapping the input tensor.

        Returns:
            torch.Tensor: Positional encoding in shape (B, num_pos_feats*2, H, W)
        """
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos


@POS_EMBEDDINGS.register_module()
class PositionEmbeddingSineHW(nn.Module):
    """This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images, with
    separate temperatures for the height and width axes.

    Args:
        num_pos_feats (int): The channel dimension of the positional embeddings.
        temperatureH (float): The temperature used along the height axis.
        temperatureW (float): The temperature used along the width axis.
        normalize (bool): Whether to normalize the positional embeddings.
        scale (float): The scale factor of the positional embeddings.
    """

    def __init__(
        self,
        num_pos_feats: int = 64,
        temperatureH: int = 10000,
        temperatureW: int = 10000,
        normalize: bool = False,
        scale: float = None,
    ) -> None:
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperatureH = temperatureH
        self.temperatureW = temperatureW
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: NestedTensor) -> torch.Tensor:
        """Forward function.

        Args:
            tensor_list (NestedTensor): NestedTensor wrapping the input tensor.

        Returns:
            torch.Tensor: Positional encoding in shape (B, num_pos_feats*2, H, W)
        """
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)

        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_tx = self.temperatureW ** (2 * (dim_tx // 2) / self.num_pos_feats)
        pos_x = x_embed[:, :, :, None] / dim_tx

        dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_ty = self.temperatureH ** (2 * (dim_ty // 2) / self.num_pos_feats)
        pos_y = y_embed[:, :, :, None] / dim_ty

        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)

        return pos

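For orientation, a minimal usage sketch of the sine embedding above. It assumes the DETR-style NestedTensor(tensors, mask) constructor exported by detect_tools.upn.models.module (not shown in this diff); sizes are illustrative, not values the repo pins down.

import torch
from detect_tools.upn.models.module import NestedTensor

feat = torch.rand(2, 256, 32, 32)                # (B, C, H, W) feature map
mask = torch.zeros(2, 32, 32, dtype=torch.bool)  # False marks valid (non-padded) pixels
pe = PositionEmbeddingSineHW(num_pos_feats=128, normalize=True)
pos = pe(NestedTensor(feat, mask))
assert pos.shape == (2, 256, 32, 32)             # (B, num_pos_feats*2, H, W)
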
@POS_EMBEDDINGS.register_module()
class PositionEmbeddingLearned(nn.Module):
    """Absolute pos embedding, learned.

    Args:
        num_pos_feats (int): The channel dimension of positional embeddings.
        num_row (int): The number of rows of the input feature map.
        num_col (int): The number of columns of the input feature map.
    """

    def __init__(
        self, num_row: int = 50, num_col: int = 50, num_pos_feats: int = 256
    ) -> None:
        super().__init__()
        self.row_embed = nn.Embedding(num_row, num_pos_feats)
        self.col_embed = nn.Embedding(num_col, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

    def forward(self, tensor_list: NestedTensor) -> torch.Tensor:
        """Forward function.

        Args:
            tensor_list (NestedTensor): NestedTensor wrapping the input tensor.

        Returns:
            torch.Tensor: Positional encoding in shape (B, num_pos_feats*2, H, W)
        """
        x = tensor_list.tensors
        h, w = x.shape[-2:]
        i = torch.arange(w, device=x.device)
        j = torch.arange(h, device=x.device)
        x_emb = self.col_embed(i)
        y_emb = self.row_embed(j)
        pos = (
            torch.cat(
                [
                    x_emb.unsqueeze(0).repeat(h, 1, 1),
                    y_emb.unsqueeze(1).repeat(1, w, 1),
                ],
                dim=-1,
            )
            .permute(2, 0, 1)
            .unsqueeze(0)
            .repeat(x.shape[0], 1, 1, 1)
        )
        return pos

def build_position_encoding(args):
    N_steps = args.hidden_dim // 2
    if args.position_embedding in ("v2", "sine"):
        # TODO find a better way of exposing other arguments
        position_embedding = PositionEmbeddingSineHW(
            N_steps,
            temperatureH=args.pe_temperatureH,
            temperatureW=args.pe_temperatureW,
            normalize=True,
        )
    elif args.position_embedding in ("v3", "learned"):
        position_embedding = PositionEmbeddingLearned(N_steps)
    else:
        raise ValueError(f"not supported {args.position_embedding}")

    return position_embedding


def clean_state_dict(state_dict):
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k[:7] == "module.":
            k = k[7:]  # remove the `module.` prefix added by DataParallel
        new_state_dict[k] = v
    return new_state_dict


def get_activation_fn(activation: str, d_model: int = 256, batch_dim: int = 0):
    """Return an activation function given a string

    Args:
        activation (str): activation function name
        d_model (int, optional): d_model. Defaults to 256.
        batch_dim (int, optional): batch dimension. Defaults to 0.

    Returns:
        F: activation function
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    if activation == "prelu":
        return nn.PReLU()
    if activation == "selu":
        return F.selu

    raise RuntimeError(f"activation should be relu/gelu/glu/prelu/selu, not {activation}.")


def get_clones(module: nn.Module, N: int, layer_share: bool = False):
    """Copy a module N times

    Args:
        module (nn.Module): module to copy
        N (int): number of copies
        layer_share (bool, optional): share the same layer. If true, the modules will
            share the same memory. Defaults to False.
    """
    if layer_share:
        return nn.ModuleList([module for _ in range(N)])
    else:
        return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def inverse_sigmoid(x, eps=1e-3):
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)

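inverse_sigmoid above is the logit function with clamping, so coordinates at exactly 0 or 1 still map to finite logits. A quick self-contained round-trip check (eps left at its 1e-3 default):

import torch

x = torch.tensor([0.0, 0.25, 0.5, 0.9, 1.0])
logits = inverse_sigmoid(x)  # log(x1 / x2) after clamping both sides to at least 1e-3
assert torch.isfinite(logits).all()
assert torch.allclose(torch.sigmoid(logits), x.clamp(1e-3, 1 - 1e-3), atol=1e-5)
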
def gen_sineembed_for_position(pos_tensor):
    # n_query, bs, _ = pos_tensor.size()
    # sineembed_tensor = torch.zeros(n_query, bs, 256)
    scale = 2 * math.pi
    dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
    dim_t = 10000 ** (2 * (dim_t // 2) / 128)
    x_embed = pos_tensor[:, :, 0] * scale
    y_embed = pos_tensor[:, :, 1] * scale
    pos_x = x_embed[:, :, None] / dim_t
    pos_y = y_embed[:, :, None] / dim_t
    pos_x = torch.stack(
        (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3
    ).flatten(2)
    pos_y = torch.stack(
        (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3
    ).flatten(2)
    if pos_tensor.size(-1) == 2:
        pos = torch.cat((pos_y, pos_x), dim=2)
    elif pos_tensor.size(-1) == 4:
        w_embed = pos_tensor[:, :, 2] * scale
        pos_w = w_embed[:, :, None] / dim_t
        pos_w = torch.stack(
            (pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3
        ).flatten(2)

        h_embed = pos_tensor[:, :, 3] * scale
        pos_h = h_embed[:, :, None] / dim_t
        pos_h = torch.stack(
            (pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3
        ).flatten(2)

        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    else:
        raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
    return pos


def get_sine_pos_embed(
    pos_tensor: torch.Tensor,
    num_pos_feats: int = 128,
    temperature: int = 10000,
    exchange_xy: bool = True,
):
    """generate sine position embedding from a position tensor
    Args:
        pos_tensor (torch.Tensor): shape: [..., n].
        num_pos_feats (int): projected shape for each float in the tensor.
        temperature (int): temperature in the sine/cosine function.
        exchange_xy (bool, optional): exchange pos x and pos y. \
            For example, input tensor is [x,y], the results will be [pos(y), pos(x)]. Defaults to True.
    Returns:
        pos_embed (torch.Tensor): shape: [..., n*num_pos_feats].
    """
    scale = 2 * math.pi
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device)
    dim_t = temperature ** (
        2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats
    )

    def sine_func(x: torch.Tensor):
        sin_x = x * scale / dim_t
        sin_x = torch.stack(
            (sin_x[..., 0::2].sin(), sin_x[..., 1::2].cos()), dim=3
        ).flatten(2)
        return sin_x

    pos_res = [
        sine_func(x) for x in pos_tensor.split([1] * pos_tensor.shape[-1], dim=-1)
    ]
    if exchange_xy:
        pos_res[0], pos_res[1] = pos_res[1], pos_res[0]
    pos_res = torch.cat(pos_res, dim=-1)
    return pos_res


def gen_encoder_output_proposals(
    memory: torch.Tensor,
    memory_padding_mask: torch.Tensor,
    spatial_shapes: torch.Tensor,
    learnedwh=None,
):
    """
    Input:
        - memory: bs, \sum{hw}, d_model
        - memory_padding_mask: bs, \sum{hw}
        - spatial_shapes: nlevel, 2
        - learnedwh: 2
    Output:
        - output_memory: bs, \sum{hw}, d_model
        - output_proposals: bs, \sum{hw}, 4
    """
    N_, S_, C_ = memory.shape
    base_scale = 4.0
    proposals = []
    _cur = 0
    for lvl, (H_, W_) in enumerate(spatial_shapes):
        mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view(
            N_, H_, W_, 1
        )
        valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
        valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
        grid_y, grid_x = torch.meshgrid(
            torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
            torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
        )
        grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)  # H_, W_, 2

        scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(
            N_, 1, 1, 2
        )
        grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale

        if learnedwh is not None:
            wh = torch.ones_like(grid) * learnedwh.sigmoid() * (2.0**lvl)
        else:
            wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)

        proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
        proposals.append(proposal)
        _cur += H_ * W_

    output_proposals = torch.cat(proposals, 1)
    output_proposals_valid = (
        (output_proposals > 0.01) & (output_proposals < 0.99)
    ).all(-1, keepdim=True)
    output_proposals = torch.log(output_proposals / (1 - output_proposals))  # unsigmoid
    output_proposals = output_proposals.masked_fill(
        memory_padding_mask.unsqueeze(-1), float("inf")
    )
    output_proposals = output_proposals.masked_fill(
        ~output_proposals_valid, float("inf")
    )

    output_memory = memory
    output_memory = output_memory.masked_fill(
        memory_padding_mask.unsqueeze(-1), float(0)
    )
    output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))

    return output_memory, output_proposals
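To make gen_encoder_output_proposals concrete: for a single unpadded 4x4 level, every cell becomes a box centered on that cell with the default level-0 size of 0.05, returned in unsigmoided (logit) space. A minimal sketch with illustrative shapes:

import torch

memory = torch.rand(1, 16, 256)                      # bs=1, one 4x4 level, d_model=256
padding_mask = torch.zeros(1, 16, dtype=torch.bool)  # nothing is padded
out_mem, proposals = gen_encoder_output_proposals(memory, padding_mask, [(4, 4)])
boxes = proposals.sigmoid()                          # back to normalized cx, cy, w, h
assert boxes.shape == (1, 16, 4)
print(boxes[0, 0])                                   # approx. [0.1250, 0.1250, 0.0500, 0.0500]
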
detect_tools/upn/ops/functions/__init__.py
ADDED
@@ -0,0 +1,10 @@

# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------

from .ms_deform_attn_func import MSDeformAttnFunction
detect_tools/upn/ops/functions/ms_deform_attn_func.py
ADDED
@@ -0,0 +1,61 @@

# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable

import MultiScaleDeformableAttention as MSDA


class MSDeformAttnFunction(Function):
    @staticmethod
    def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
        ctx.im2col_step = im2col_step
        output = MSDA.ms_deform_attn_forward(
            value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
        ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
        grad_value, grad_sampling_loc, grad_attn_weight = \
            MSDA.ms_deform_attn_backward(
                value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)

        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None


def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
    # for debug and test only,
    # need to use cuda version instead
    N_, S_, M_, D_ = value.shape
    _, Lq_, M_, L_, P_, _ = sampling_locations.shape
    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for lid_, (H_, W_) in enumerate(value_spatial_shapes):
        # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
        value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
        # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
        sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
        # N_*M_, D_, Lq_, P_
        sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
                                          mode='bilinear', padding_mode='zeros', align_corners=False)
        sampling_value_list.append(sampling_value_l_)
    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
    attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
    return output.transpose(1, 2).contiguous()
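The pure-PyTorch reference ms_deform_attn_core_pytorch above documents the shape contract of the op. A self-contained sketch with illustrative shapes (two feature levels, 8 heads):

import torch
from detect_tools.upn.ops.functions.ms_deform_attn_func import ms_deform_attn_core_pytorch

N, M, D, Lq, P = 2, 8, 32, 10, 4           # batch, heads, head dim, queries, points
shapes = [(8, 8), (4, 4)]                  # two feature levels
L, S = len(shapes), sum(h * w for h, w in shapes)
value = torch.rand(N, S, M, D)
sampling_locations = torch.rand(N, Lq, M, L, P, 2)  # normalized [0, 1] coordinates
attention_weights = torch.softmax(torch.rand(N, Lq, M, L * P), -1).view(N, Lq, M, L, P)
out = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights)
assert out.shape == (N, Lq, M * D)         # (2, 10, 256)
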
detect_tools/upn/ops/modules/__init__.py
ADDED
@@ -0,0 +1,9 @@

# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------

from .ms_deform_attn import MSDeformAttn
detect_tools/upn/ops/modules/ms_deform_attn.py
ADDED
@@ -0,0 +1,204 @@

# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import warnings
import math, os

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_

try:
    from ..functions import MSDeformAttnFunction
except:
    warnings.warn("Failed to import MSDeformAttnFunction.")


def _is_power_of_2(n):
    if (not isinstance(n, int)) or (n < 0):
        raise ValueError(
            "invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))
        )
    return (n & (n - 1) == 0) and n != 0


class MSDeformAttn(nn.Module):
    def __init__(
        self, d_model=256, n_levels=4, n_heads=8, n_points=4, use_4D_normalizer=False
    ):
        """
        Multi-Scale Deformable Attention Module
        :param d_model      hidden dimension
        :param n_levels     number of feature levels
        :param n_heads      number of attention heads
        :param n_points     number of sampling points per attention head per feature level
        """
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                "d_model must be divisible by n_heads, but got {} and {}".format(
                    d_model, n_heads
                )
            )
        _d_per_head = d_model // n_heads
        # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
        if not _is_power_of_2(_d_per_head):
            warnings.warn(
                "You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
                "which is more efficient in our CUDA implementation."
            )

        self.im2col_step = 64

        self.d_model = d_model
        self.n_levels = n_levels
        self.n_heads = n_heads
        self.n_points = n_points

        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
        self.value_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, d_model)

        self.use_4D_normalizer = use_4D_normalizer

        self._reset_parameters()

    def _reset_parameters(self):
        constant_(self.sampling_offsets.weight.data, 0.0)
        thetas = torch.arange(self.n_heads, dtype=torch.float32) * (
            2.0 * math.pi / self.n_heads
        )
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (
            (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
            .view(self.n_heads, 1, 1, 2)
            .repeat(1, self.n_levels, self.n_points, 1)
        )
        for i in range(self.n_points):
            grid_init[:, :, i, :] *= i + 1
        with torch.no_grad():
            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
        constant_(self.attention_weights.weight.data, 0.0)
        constant_(self.attention_weights.bias.data, 0.0)
        xavier_uniform_(self.value_proj.weight.data)
        constant_(self.value_proj.bias.data, 0.0)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.0)

    @torch.cuda.amp.autocast(enabled=False)
    def forward(
        self,
        query,
        reference_points,
        input_flatten,
        input_spatial_shapes,
        input_level_start_index,
        input_padding_mask=None,
    ):
        """
        :param query                    (N, Length_{query}, C)
        :param reference_points        (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
                                    or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
        :param input_flatten            (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
        :param input_spatial_shapes     (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
        :param input_level_start_index  (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
        :param input_padding_mask       (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements

        :return output                  (N, Length_{query}, C)
        """
        N, Len_q, _ = query.shape
        N, Len_in, _ = input_flatten.shape
        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in

        value = self.value_proj(input_flatten)
        if input_padding_mask is not None:
            value = value.masked_fill(input_padding_mask[..., None], float(0))
        value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
        sampling_offsets = self.sampling_offsets(query).view(
            N, Len_q, self.n_heads, self.n_levels, self.n_points, 2
        )
        attention_weights = self.attention_weights(query).view(
            N, Len_q, self.n_heads, self.n_levels * self.n_points
        )
        attention_weights = F.softmax(attention_weights, -1).view(
            N, Len_q, self.n_heads, self.n_levels, self.n_points
        )
        # N, Len_q, n_heads, n_levels, n_points, 2

        # if os.environ.get('IPDB_DEBUG_SHILONG', False) == 'INFO':
        #     import ipdb; ipdb.set_trace()

        if reference_points.shape[-1] == 2:
            offset_normalizer = torch.stack(
                [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1
            )
            sampling_locations = (
                reference_points[:, :, None, :, None, :]
                + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
            )
        elif reference_points.shape[-1] == 4:
            if self.use_4D_normalizer:
                offset_normalizer = torch.stack(
                    [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1
                )
                sampling_locations = (
                    reference_points[:, :, None, :, None, :2]
                    + sampling_offsets
                    / offset_normalizer[None, None, None, :, None, :]
                    * reference_points[:, :, None, :, None, 2:]
                    * 0.5
                )
            else:
                sampling_locations = (
                    reference_points[:, :, None, :, None, :2]
                    + sampling_offsets
                    / self.n_points
                    * reference_points[:, :, None, :, None, 2:]
                    * 0.5
                )
        else:
            raise ValueError(
                "Last dim of reference_points must be 2 or 4, but get {} instead.".format(
                    reference_points.shape[-1]
                )
            )

        # if os.environ.get('IPDB_DEBUG_SHILONG', False) == 'INFO':
        #     import ipdb; ipdb.set_trace()

        # for amp
        if value.dtype == torch.float16:
            # for mixed precision
            output = MSDeformAttnFunction.apply(
                value.to(torch.float32),
                input_spatial_shapes,
                input_level_start_index,
                sampling_locations.to(torch.float32),
                attention_weights,
                self.im2col_step,
            )
            output = output.to(torch.float16)
            output = self.output_proj(output)
            return output

        output = MSDeformAttnFunction.apply(
            value,
            input_spatial_shapes,
            input_level_start_index,
            sampling_locations,
            attention_weights,
            self.im2col_step,
        )
        output = self.output_proj(output)
        return output
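A minimal usage sketch of MSDeformAttn, assuming the compiled MultiScaleDeformableAttention extension (built from ops/setup.py below) and a CUDA device; all shapes here are illustrative:

import torch
from detect_tools.upn.ops.modules import MSDeformAttn

attn = MSDeformAttn(d_model=256, n_levels=2, n_heads=8, n_points=4).cuda()
shapes = torch.as_tensor([[8, 8], [4, 4]], dtype=torch.long, device="cuda")
areas = shapes[:, 0] * shapes[:, 1]
start_index = torch.cat((areas.new_zeros(1), areas.cumsum(0)[:-1]))  # [0, 64]
src = torch.rand(2, int(areas.sum()), 256, device="cuda")            # flattened levels
query = torch.rand(2, 10, 256, device="cuda")
ref = torch.rand(2, 10, 2, 2, device="cuda")  # (N, Len_q, n_levels, 2) in [0, 1]
out = attn(query, ref, src, shapes, start_index)
assert out.shape == (2, 10, 256)
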
detect_tools/upn/ops/modules/ms_deform_attn_key_aware.py
ADDED
@@ -0,0 +1,130 @@

# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import warnings
import math, os

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_

try:
    from ..functions import MSDeformAttnFunction
except:
    warnings.warn('Failed to import MSDeformAttnFunction.')


def _is_power_of_2(n):
    if (not isinstance(n, int)) or (n < 0):
        raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
    return (n & (n-1) == 0) and n != 0


class MSDeformAttn(nn.Module):
    def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, use_4D_normalizer=False):
        """
        Multi-Scale Deformable Attention Module
        :param d_model      hidden dimension
        :param n_levels     number of feature levels
        :param n_heads      number of attention heads
        :param n_points     number of sampling points per attention head per feature level
        """
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
        _d_per_head = d_model // n_heads
        # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
        if not _is_power_of_2(_d_per_head):
            warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
                          "which is more efficient in our CUDA implementation.")

        self.im2col_step = 64

        self.d_model = d_model
        self.n_levels = n_levels
        self.n_heads = n_heads
        self.n_points = n_points

        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
        self.value_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, d_model)

        self.use_4D_normalizer = use_4D_normalizer

        self._reset_parameters()

    def _reset_parameters(self):
        constant_(self.sampling_offsets.weight.data, 0.)
        thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
        for i in range(self.n_points):
            grid_init[:, :, i, :] *= i + 1
        with torch.no_grad():
            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
        constant_(self.attention_weights.weight.data, 0.)
        constant_(self.attention_weights.bias.data, 0.)
        xavier_uniform_(self.value_proj.weight.data)
        constant_(self.value_proj.bias.data, 0.)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.)

    def forward(self, query, key, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
        """
        :param query                    (N, Length_{query}, C)
        :param key                      (N, 1, C)
        :param reference_points        (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
                                    or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
        :param input_flatten            (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
        :param input_spatial_shapes     (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
        :param input_level_start_index  (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
        :param input_padding_mask       (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements

        :return output                  (N, Length_{query}, C)
        """
        N, Len_q, _ = query.shape
        N, Len_in, _ = input_flatten.shape
        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in

        value = self.value_proj(input_flatten)
        if input_padding_mask is not None:
            value = value.masked_fill(input_padding_mask[..., None], float(0))
        value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
        sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
        attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
        attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
        # N, Len_q, n_heads, n_levels, n_points, 2

        # if os.environ.get('IPDB_DEBUG_SHILONG', False) == 'INFO':
        #     import ipdb; ipdb.set_trace()

        if reference_points.shape[-1] == 2:
            offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
        elif reference_points.shape[-1] == 4:
            if self.use_4D_normalizer:
                offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
                sampling_locations = reference_points[:, :, None, :, None, :2] \
                    + sampling_offsets / offset_normalizer[None, None, None, :, None, :] * reference_points[:, :, None, :, None, 2:] * 0.5
            else:
                sampling_locations = reference_points[:, :, None, :, None, :2] \
                    + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
        else:
            raise ValueError(
                'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
        output = MSDeformAttnFunction.apply(
            value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
        output = self.output_proj(output)
        return output
detect_tools/upn/ops/setup.py
ADDED
@@ -0,0 +1,73 @@

# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------

import os
import glob

import torch

from torch.utils.cpp_extension import CUDA_HOME
from torch.utils.cpp_extension import CppExtension
from torch.utils.cpp_extension import CUDAExtension

from setuptools import find_packages
from setuptools import setup

requirements = ["torch", "torchvision"]

def get_extensions():
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, "src")

    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
    source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
    source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))

    sources = main_file + source_cpu
    extension = CppExtension
    extra_compile_args = {"cxx": []}
    define_macros = []

    # import ipdb; ipdb.set_trace()

    if torch.cuda.is_available() and CUDA_HOME is not None:
        extension = CUDAExtension
        sources += source_cuda
        define_macros += [("WITH_CUDA", None)]
        extra_compile_args["nvcc"] = [
            "-DCUDA_HAS_FP16=1",
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
        ]
    else:
        raise NotImplementedError("CUDA is not available")

    sources = [os.path.join(extensions_dir, s) for s in sources]
    include_dirs = [extensions_dir]
    ext_modules = [
        extension(
            "MultiScaleDeformableAttention",
            sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]
    return ext_modules

setup(
    name="MultiScaleDeformableAttention",
    version="1.0",
    author="Weijie Su",
    url="https://github.com/fundamentalvision/Deformable-DETR",
    description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
    packages=find_packages(exclude=("configs", "tests",)),
    ext_modules=get_extensions(),
    cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)
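This is a standard setuptools CUDA build (typically compiled with `python setup.py build install` from inside detect_tools/upn/ops, with a CUDA toolkit matching the installed torch). A small smoke test that the compiled module exposes the two entry points consumed by MSDeformAttnFunction above; this is a sketch, not part of the repo:

import torch
import MultiScaleDeformableAttention as MSDA  # module name set by setup(name=...)

print(torch.cuda.is_available())
print(hasattr(MSDA, "ms_deform_attn_forward"), hasattr(MSDA, "ms_deform_attn_backward"))
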
detect_tools/upn/ops/src/cpu/ms_deform_attn_cpu.cpp
ADDED
@@ -0,0 +1,41 @@

/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#include <vector>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>


at::Tensor
ms_deform_attn_cpu_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step)
{
    AT_ERROR("Not implemented on CPU");
}

std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step)
{
    AT_ERROR("Not implemented on CPU");
}
detect_tools/upn/ops/src/cpu/ms_deform_attn_cpu.h
ADDED
@@ -0,0 +1,33 @@

/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#pragma once
#include <torch/extension.h>

at::Tensor
ms_deform_attn_cpu_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step);

std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step);
detect_tools/upn/ops/src/cuda/ms_deform_attn_cuda.cu
ADDED
@@ -0,0 +1,153 @@

/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#include <vector>
#include "cuda/ms_deform_im2col_cuda.cuh"

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>


at::Tensor ms_deform_attn_cuda_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step)
{
    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");

    AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
    AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
    AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
    AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");

    const int batch = value.size(0);
    const int spatial_size = value.size(1);
    const int num_heads = value.size(2);
    const int channels = value.size(3);

    const int num_levels = spatial_shapes.size(0);

    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(4);

    const int im2col_step_ = std::min(batch, im2col_step);

    AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);

    auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());

    const int batch_n = im2col_step_;
    auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
    auto per_value_size = spatial_size * num_heads * channels;
    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
    for (int n = 0; n < batch/im2col_step_; ++n)
    {
        auto columns = output_n.select(0, n);
        AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
            ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
                value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
                spatial_shapes.data_ptr<int64_t>(),
                level_start_index.data_ptr<int64_t>(),
                sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
                columns.data_ptr<scalar_t>());

        }));
    }

    output = output.view({batch, num_query, num_heads*channels});

    return output;
}


std::vector<at::Tensor> ms_deform_attn_cuda_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step)
{

    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
    AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");

    AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
    AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
    AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
    AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
    AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor");

    const int batch = value.size(0);
    const int spatial_size = value.size(1);
    const int num_heads = value.size(2);
    const int channels = value.size(3);

    const int num_levels = spatial_shapes.size(0);

    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(4);

    const int im2col_step_ = std::min(batch, im2col_step);

    AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);

    auto grad_value = at::zeros_like(value);
    auto grad_sampling_loc = at::zeros_like(sampling_loc);
    auto grad_attn_weight = at::zeros_like(attn_weight);

    const int batch_n = im2col_step_;
    auto per_value_size = spatial_size * num_heads * channels;
    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
    auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});

    for (int n = 0; n < batch/im2col_step_; ++n)
    {
        auto grad_output_g = grad_output_n.select(0, n);
        AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
            ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
                grad_output_g.data_ptr<scalar_t>(),
                value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
                spatial_shapes.data_ptr<int64_t>(),
                level_start_index.data_ptr<int64_t>(),
                sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
                grad_value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
                grad_sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                grad_attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size);

        }));
    }

    return {
        grad_value, grad_sampling_loc, grad_attn_weight
    };
}
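One detail worth calling out: both kernels clamp im2col_step to the batch size and then assert that the batch splits evenly into chunks of that size (the Python wrapper sets im2col_step = 64). The same constraint in Python terms, as a small illustrative sketch:

def check_im2col_step(batch: int, im2col_step: int = 64) -> int:
    # Mirrors `const int im2col_step_ = std::min(batch, im2col_step);` plus the assert.
    im2col_step_ = min(batch, im2col_step)
    assert batch % im2col_step_ == 0, f"batch({batch}) must divide im2col_step({im2col_step_})"
    return im2col_step_

check_im2col_step(48)    # fine: one chunk of 48
check_im2col_step(128)   # fine: two chunks of 64
# check_im2col_step(96)  # would raise: 96 % 64 != 0
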
detect_tools/upn/ops/src/cuda/ms_deform_attn_cuda.h
ADDED
@@ -0,0 +1,30 @@

/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#pragma once
#include <torch/extension.h>

at::Tensor ms_deform_attn_cuda_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step);

std::vector<at::Tensor> ms_deform_attn_cuda_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step);
detect_tools/upn/ops/src/cuda/ms_deform_im2col_cuda.cuh
ADDED
@@ -0,0 +1,1327 @@
/*!
**************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************
* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
* Copyright (c) 2018 Microsoft
**************************************************************************
*/

#include <cstdio>
#include <algorithm>
#include <cstring>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <THC/THCAtomics.cuh>

#define CUDA_KERNEL_LOOP(i, n)                        \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
      i < (n);                                        \
      i += blockDim.x * gridDim.x)

const int CUDA_NUM_THREADS = 1024;
inline int GET_BLOCKS(const int N, const int num_threads)
{
  return (N + num_threads - 1) / num_threads;
}

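CUDA_KERNEL_LOOP is the standard grid-stride loop: a thread starts at its global index and advances by the total thread count of the launch, so any problem size n is covered even when the grid is smaller than n, and GET_BLOCKS is plain ceiling division for sizing that grid. The same pattern in a self-contained sketch (scale_kernel is an illustrative example, not part of this file):

// Standalone illustration of the grid-stride loop behind CUDA_KERNEL_LOOP.
__global__ void scale_kernel(const int n, const float alpha, float *x)
{
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x)   // same stride as CUDA_KERNEL_LOOP
    x[i] *= alpha;
}

// Usage: GET_BLOCKS(n, CUDA_NUM_THREADS) blocks cover all n elements, e.g.
//   scale_kernel<<<GET_BLOCKS(n, 1024), 1024>>>(n, 2.0f, d_x);
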
template <typename scalar_t>
__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
                                                   const int &height, const int &width, const int &nheads, const int &channels,
                                                   const scalar_t &h, const scalar_t &w, const int &m, const int &c)
{
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
  }

  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
}

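This helper samples one (head, channel) slice of bottom_data at the fractional location (h, w): the four integer neighbors are read when they fall inside the height x width map (out-of-range corners contribute zero) and blended with weights that always sum to 1. A small host-side sketch of the same weight computation (the helper name is illustrative):

// Host-side illustration of the bilinear weights used above.
inline void bilinear_weights(float h, float w, float out[4])
{
  const float lh = h - floorf(h), lw = w - floorf(w);
  out[0] = (1 - lh) * (1 - lw);  // w1: top-left     (h_low,  w_low)
  out[1] = (1 - lh) * lw;        // w2: top-right    (h_low,  w_high)
  out[2] = lh * (1 - lw);        // w3: bottom-left  (h_high, w_low)
  out[3] = lh * lw;              // w4: bottom-right (h_high, w_high)
}
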
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
                                               const int &height, const int &width, const int &nheads, const int &channels,
                                               const scalar_t &h, const scalar_t &w, const int &m, const int &c,
                                               const scalar_t &top_grad,
                                               const scalar_t &attn_weight,
                                               scalar_t* &grad_value,
                                               scalar_t* grad_sampling_loc,
                                               scalar_t* grad_attn_weight)
{
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  const scalar_t top_grad_value = top_grad * attn_weight;
  scalar_t grad_h_weight = 0, grad_w_weight = 0;

  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
    grad_h_weight -= hw * v1;
    grad_w_weight -= hh * v1;
    atomicAdd(grad_value+ptr1, w1*top_grad_value);
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
    grad_h_weight -= lw * v2;
    grad_w_weight += hh * v2;
    atomicAdd(grad_value+ptr2, w2*top_grad_value);
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
    grad_h_weight += hw * v3;
    grad_w_weight -= lh * v3;
    atomicAdd(grad_value+ptr3, w3*top_grad_value);
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
    grad_h_weight += lw * v4;
    grad_w_weight += lh * v4;
    atomicAdd(grad_value+ptr4, w4*top_grad_value);
  }

  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  *grad_attn_weight = top_grad * val;
  *grad_sampling_loc = width * grad_w_weight * top_grad_value;
  *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
}

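The closing three stores implement the analytic derivatives of the bilinear sample v = w1*v1 + w2*v2 + w3*v3 + w4*v4. With incoming gradient g (top_grad) and attention weight a, and with W and H factors because the callers pass sampling locations normalized to [0, 1]:

% Gradients written by ms_deform_attn_col2im_bilinear for one sampled point
\frac{\partial L}{\partial a} = g\,v, \qquad
\frac{\partial L}{\partial \mathrm{loc}_w} = W \cdot g\,a \cdot \frac{\partial v}{\partial w}, \qquad
\frac{\partial L}{\partial \mathrm{loc}_h} = H \cdot g\,a \cdot \frac{\partial v}{\partial h}
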
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
                                                  const int &height, const int &width, const int &nheads, const int &channels,
                                                  const scalar_t &h, const scalar_t &w, const int &m, const int &c,
                                                  const scalar_t &top_grad,
                                                  const scalar_t &attn_weight,
                                                  scalar_t* &grad_value,
                                                  scalar_t* grad_sampling_loc,
                                                  scalar_t* grad_attn_weight)
{
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  const scalar_t top_grad_value = top_grad * attn_weight;
  scalar_t grad_h_weight = 0, grad_w_weight = 0;

  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
    grad_h_weight -= hw * v1;
    grad_w_weight -= hh * v1;
    atomicAdd(grad_value+ptr1, w1*top_grad_value);
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
    grad_h_weight -= lw * v2;
    grad_w_weight += hh * v2;
    atomicAdd(grad_value+ptr2, w2*top_grad_value);
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
    grad_h_weight += hw * v3;
    grad_w_weight -= lh * v3;
    atomicAdd(grad_value+ptr3, w3*top_grad_value);
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
    grad_h_weight += lw * v4;
    grad_w_weight += lh * v4;
    atomicAdd(grad_value+ptr4, w4*top_grad_value);
  }

  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  atomicAdd(grad_attn_weight, top_grad * val);
  atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
  atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
}

template <typename scalar_t>
__global__ void ms_deformable_im2col_gpu_kernel(const int n,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *data_col)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    scalar_t *data_col_ptr = data_col + index;
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
    scalar_t col = 0;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;

        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
        }

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
      }
    }
    *data_col_ptr = col;
  }
}

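The `% / %` chain at the top of this kernel decodes the flat output index as a mixed-radix number over (batch, query, head, channel). For reference, its inverse maps an output element's coordinates back to the linear index the kernel iterates over (the helper name below is an illustration, not part of the file):

// Illustrative inverse of the kernel's index decomposition.
__host__ __device__ inline int flatten_output_index(
    int b, int q, int m, int c,
    int num_query, int num_heads, int channels)
{
  return ((b * num_query + q) * num_heads + m) * channels + c;
}
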
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
    __shared__ scalar_t cache_grad_attn_weight[blockSize];
    unsigned int tid = threadIdx.x;
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();
        if (tid == 0)
        {
          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
          int sid=2;
          for (unsigned int tid = 1; tid < blockSize; ++tid)
          {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
            _grad_a += cache_grad_attn_weight[tid];
            sid += 2;
          }

          *grad_sampling_loc = _grad_w;
          *(grad_sampling_loc + 1) = _grad_h;
          *grad_attn_weight = _grad_a;
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}

template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
    __shared__ scalar_t cache_grad_attn_weight[blockSize];
    unsigned int tid = threadIdx.x;
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        for (unsigned int s=blockSize/2; s>0; s>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
          }
          __syncthreads();
        }

        if (tid == 0)
        {
          *grad_sampling_loc = cache_grad_sampling_loc[0];
          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
          *grad_attn_weight = cache_grad_attn_weight[0];
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}

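Where the v1 kernel has thread 0 walk the whole cache serially, this v2 kernel reduces it as a shared-memory tree: each round halves the number of active threads, finishing in log2(blockSize) fenced rounds. The same pattern in isolation (a power-of-two block size and an input length that is a multiple of blockSize are assumed here; the kernel name is illustrative):

// Minimal shared-memory tree reduction over one block.
template <unsigned int blockSize>
__global__ void block_sum(const float *in, float *out)
{
  __shared__ float cache[blockSize];
  const unsigned int tid = threadIdx.x;
  cache[tid] = in[blockIdx.x * blockSize + tid];
  __syncthreads();
  for (unsigned int s = blockSize / 2; s > 0; s >>= 1)
  {
    if (tid < s) cache[tid] += cache[tid + s];
    __syncthreads();   // every halving round must be fenced
  }
  if (tid == 0) out[blockIdx.x] = cache[0];
}
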
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();
        if (tid == 0)
        {
          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
          int sid=2;
          for (unsigned int tid = 1; tid < blockDim.x; ++tid)
          {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
            _grad_a += cache_grad_attn_weight[tid];
            sid += 2;
          }

          *grad_sampling_loc = _grad_w;
          *(grad_sampling_loc + 1) = _grad_h;
          *grad_attn_weight = _grad_a;
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}

template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            if (tid + (s << 1) < spre)
            {
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        if (tid == 0)
        {
          *grad_sampling_loc = cache_grad_sampling_loc[0];
          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
          *grad_attn_weight = cache_grad_attn_weight[0];
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}

template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            if (tid + (s << 1) < spre)
            {
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        if (tid == 0)
        {
          atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
          atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
          atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}

template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear_gm(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            grad_sampling_loc, grad_attn_weight);
        }
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}

template <typename scalar_t>
void ms_deformable_im2col_cuda(cudaStream_t stream,
                              const scalar_t* data_value,
                              const int64_t* data_spatial_shapes,
                              const int64_t* data_level_start_index,
                              const scalar_t* data_sampling_loc,
                              const scalar_t* data_attn_weight,
                              const int batch_size,
                              const int spatial_size,
                              const int num_heads,
                              const int channels,
                              const int num_levels,
                              const int num_query,
                              const int num_point,
                              scalar_t* data_col)
{
  const int num_kernels = batch_size * num_query * num_heads * channels;
  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
  const int num_threads = CUDA_NUM_THREADS;
  ms_deformable_im2col_gpu_kernel<scalar_t>
      <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
          0, stream>>>(
      num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
      batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
  }

}

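On the host side this launcher is reached through an AT_DISPATCH_FLOATING_TYPES block, which instantiates the template once per floating dtype. A hedged sketch of that call site, modeled on the stock Deformable DETR wrapper (tensor and stride names such as columns, batch_n, and per_value_size are assumptions of the sketch, not quoted from this commit's ms_deform_attn_cuda.cu):

// Sketch: per-batch-slice dispatch into ms_deformable_im2col_cuda.
AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
  ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
      value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
      spatial_shapes.data_ptr<int64_t>(),
      level_start_index.data_ptr<int64_t>(),
      sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
      attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
      batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
      columns.data_ptr<scalar_t>());
}));
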
| 956 |
+
template <typename scalar_t>
|
| 957 |
+
void ms_deformable_col2im_cuda(cudaStream_t stream,
|
| 958 |
+
const scalar_t* grad_col,
|
| 959 |
+
const scalar_t* data_value,
|
| 960 |
+
const int64_t * data_spatial_shapes,
|
| 961 |
+
const int64_t * data_level_start_index,
|
| 962 |
+
const scalar_t * data_sampling_loc,
|
| 963 |
+
const scalar_t * data_attn_weight,
|
| 964 |
+
const int batch_size,
|
| 965 |
+
const int spatial_size,
|
| 966 |
+
const int num_heads,
|
| 967 |
+
const int channels,
|
| 968 |
+
const int num_levels,
|
| 969 |
+
const int num_query,
|
| 970 |
+
const int num_point,
|
| 971 |
+
scalar_t* grad_value,
|
| 972 |
+
scalar_t* grad_sampling_loc,
|
| 973 |
+
scalar_t* grad_attn_weight)
|
| 974 |
+
{
|
| 975 |
+
const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
|
| 976 |
+
const int num_kernels = batch_size * num_query * num_heads * channels;
|
| 977 |
+
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
|
| 978 |
+
if (channels > 1024)
|
| 979 |
+
{
|
| 980 |
+
if ((channels & 1023) == 0)
|
| 981 |
+
{
|
| 982 |
+
ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
|
| 983 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 984 |
+
num_threads*3*sizeof(scalar_t), stream>>>(
|
| 985 |
+
num_kernels,
|
| 986 |
+
grad_col,
|
| 987 |
+
data_value,
|
| 988 |
+
data_spatial_shapes,
|
| 989 |
+
data_level_start_index,
|
| 990 |
+
data_sampling_loc,
|
| 991 |
+
data_attn_weight,
|
| 992 |
+
batch_size,
|
| 993 |
+
spatial_size,
|
| 994 |
+
num_heads,
|
| 995 |
+
channels,
|
| 996 |
+
num_levels,
|
| 997 |
+
num_query,
|
| 998 |
+
num_point,
|
| 999 |
+
grad_value,
|
| 1000 |
+
grad_sampling_loc,
|
| 1001 |
+
grad_attn_weight);
|
| 1002 |
+
}
|
| 1003 |
+
else
|
| 1004 |
+
{
|
| 1005 |
+
ms_deformable_col2im_gpu_kernel_gm<scalar_t>
|
| 1006 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1007 |
+
0, stream>>>(
|
| 1008 |
+
num_kernels,
|
| 1009 |
+
grad_col,
|
| 1010 |
+
data_value,
|
| 1011 |
+
data_spatial_shapes,
|
| 1012 |
+
data_level_start_index,
|
| 1013 |
+
data_sampling_loc,
|
| 1014 |
+
data_attn_weight,
|
| 1015 |
+
batch_size,
|
| 1016 |
+
spatial_size,
|
| 1017 |
+
num_heads,
|
| 1018 |
+
channels,
|
| 1019 |
+
num_levels,
|
| 1020 |
+
num_query,
|
| 1021 |
+
num_point,
|
| 1022 |
+
grad_value,
|
| 1023 |
+
grad_sampling_loc,
|
| 1024 |
+
grad_attn_weight);
|
| 1025 |
+
}
|
| 1026 |
+
}
|
| 1027 |
+
else{
|
| 1028 |
+
switch(channels)
|
| 1029 |
+
{
|
| 1030 |
+
case 1:
|
| 1031 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
|
| 1032 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1033 |
+
0, stream>>>(
|
| 1034 |
+
num_kernels,
|
| 1035 |
+
grad_col,
|
| 1036 |
+
data_value,
|
| 1037 |
+
data_spatial_shapes,
|
| 1038 |
+
data_level_start_index,
|
| 1039 |
+
data_sampling_loc,
|
| 1040 |
+
data_attn_weight,
|
| 1041 |
+
batch_size,
|
| 1042 |
+
spatial_size,
|
| 1043 |
+
num_heads,
|
| 1044 |
+
channels,
|
| 1045 |
+
num_levels,
|
| 1046 |
+
num_query,
|
| 1047 |
+
num_point,
|
| 1048 |
+
grad_value,
|
| 1049 |
+
grad_sampling_loc,
|
| 1050 |
+
grad_attn_weight);
|
| 1051 |
+
break;
|
| 1052 |
+
case 2:
|
| 1053 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
|
| 1054 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1055 |
+
0, stream>>>(
|
| 1056 |
+
num_kernels,
|
| 1057 |
+
grad_col,
|
| 1058 |
+
data_value,
|
| 1059 |
+
data_spatial_shapes,
|
| 1060 |
+
data_level_start_index,
|
| 1061 |
+
data_sampling_loc,
|
| 1062 |
+
data_attn_weight,
|
| 1063 |
+
batch_size,
|
| 1064 |
+
spatial_size,
|
| 1065 |
+
num_heads,
|
| 1066 |
+
channels,
|
| 1067 |
+
num_levels,
|
| 1068 |
+
num_query,
|
| 1069 |
+
num_point,
|
| 1070 |
+
grad_value,
|
| 1071 |
+
grad_sampling_loc,
|
| 1072 |
+
grad_attn_weight);
|
| 1073 |
+
break;
|
| 1074 |
+
case 4:
|
| 1075 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
|
| 1076 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1077 |
+
0, stream>>>(
|
| 1078 |
+
num_kernels,
|
| 1079 |
+
grad_col,
|
| 1080 |
+
data_value,
|
| 1081 |
+
data_spatial_shapes,
|
| 1082 |
+
data_level_start_index,
|
| 1083 |
+
data_sampling_loc,
|
| 1084 |
+
data_attn_weight,
|
| 1085 |
+
batch_size,
|
| 1086 |
+
spatial_size,
|
| 1087 |
+
num_heads,
|
| 1088 |
+
channels,
|
| 1089 |
+
num_levels,
|
| 1090 |
+
num_query,
|
| 1091 |
+
num_point,
|
| 1092 |
+
grad_value,
|
| 1093 |
+
grad_sampling_loc,
|
| 1094 |
+
grad_attn_weight);
|
| 1095 |
+
break;
|
| 1096 |
+
case 8:
|
| 1097 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
|
| 1098 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1099 |
+
0, stream>>>(
|
| 1100 |
+
num_kernels,
|
| 1101 |
+
grad_col,
|
| 1102 |
+
data_value,
|
| 1103 |
+
data_spatial_shapes,
|
| 1104 |
+
data_level_start_index,
|
| 1105 |
+
data_sampling_loc,
|
| 1106 |
+
data_attn_weight,
|
| 1107 |
+
batch_size,
|
| 1108 |
+
spatial_size,
|
| 1109 |
+
num_heads,
|
| 1110 |
+
channels,
|
| 1111 |
+
num_levels,
|
| 1112 |
+
num_query,
|
| 1113 |
+
num_point,
|
| 1114 |
+
grad_value,
|
| 1115 |
+
grad_sampling_loc,
|
| 1116 |
+
grad_attn_weight);
|
| 1117 |
+
break;
|
| 1118 |
+
case 16:
|
| 1119 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
|
| 1120 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1121 |
+
0, stream>>>(
|
| 1122 |
+
+            num_kernels,
+            grad_col,
+            data_value,
+            data_spatial_shapes,
+            data_level_start_index,
+            data_sampling_loc,
+            data_attn_weight,
+            batch_size,
+            spatial_size,
+            num_heads,
+            channels,
+            num_levels,
+            num_query,
+            num_point,
+            grad_value,
+            grad_sampling_loc,
+            grad_attn_weight);
+          break;
+        case 32:
+          ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+              0, stream>>>(
+            num_kernels,
+            grad_col,
+            data_value,
+            data_spatial_shapes,
+            data_level_start_index,
+            data_sampling_loc,
+            data_attn_weight,
+            batch_size,
+            spatial_size,
+            num_heads,
+            channels,
+            num_levels,
+            num_query,
+            num_point,
+            grad_value,
+            grad_sampling_loc,
+            grad_attn_weight);
+          break;
+        case 64:
+          ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+              0, stream>>>(
+            num_kernels,
+            grad_col,
+            data_value,
+            data_spatial_shapes,
+            data_level_start_index,
+            data_sampling_loc,
+            data_attn_weight,
+            batch_size,
+            spatial_size,
+            num_heads,
+            channels,
+            num_levels,
+            num_query,
+            num_point,
+            grad_value,
+            grad_sampling_loc,
+            grad_attn_weight);
+          break;
+        case 128:
+          ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+              0, stream>>>(
+            num_kernels,
+            grad_col,
+            data_value,
+            data_spatial_shapes,
+            data_level_start_index,
+            data_sampling_loc,
+            data_attn_weight,
+            batch_size,
+            spatial_size,
+            num_heads,
+            channels,
+            num_levels,
+            num_query,
+            num_point,
+            grad_value,
+            grad_sampling_loc,
+            grad_attn_weight);
+          break;
+        case 256:
+          ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+              0, stream>>>(
+            num_kernels,
+            grad_col,
+            data_value,
+            data_spatial_shapes,
+            data_level_start_index,
+            data_sampling_loc,
+            data_attn_weight,
+            batch_size,
+            spatial_size,
+            num_heads,
+            channels,
+            num_levels,
+            num_query,
+            num_point,
+            grad_value,
+            grad_sampling_loc,
+            grad_attn_weight);
+          break;
+        case 512:
+          ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+              0, stream>>>(
+            num_kernels,
+            grad_col,
+            data_value,
+            data_spatial_shapes,
+            data_level_start_index,
+            data_sampling_loc,
+            data_attn_weight,
+            batch_size,
+            spatial_size,
+            num_heads,
+            channels,
+            num_levels,
+            num_query,
+            num_point,
+            grad_value,
+            grad_sampling_loc,
+            grad_attn_weight);
+          break;
+        case 1024:
+          ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+              0, stream>>>(
+            num_kernels,
+            grad_col,
+            data_value,
+            data_spatial_shapes,
+            data_level_start_index,
+            data_sampling_loc,
+            data_attn_weight,
+            batch_size,
+            spatial_size,
+            num_heads,
+            channels,
+            num_levels,
+            num_query,
+            num_point,
+            grad_value,
+            grad_sampling_loc,
+            grad_attn_weight);
+          break;
+        default:
+          if (channels < 64)
+          {
+            ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                num_threads*3*sizeof(scalar_t), stream>>>(
+              num_kernels,
+              grad_col,
+              data_value,
+              data_spatial_shapes,
+              data_level_start_index,
+              data_sampling_loc,
+              data_attn_weight,
+              batch_size,
+              spatial_size,
+              num_heads,
+              channels,
+              num_levels,
+              num_query,
+              num_point,
+              grad_value,
+              grad_sampling_loc,
+              grad_attn_weight);
+          }
+          else
+          {
+            ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                num_threads*3*sizeof(scalar_t), stream>>>(
+              num_kernels,
+              grad_col,
+              data_value,
+              data_spatial_shapes,
+              data_level_start_index,
+              data_sampling_loc,
+              data_attn_weight,
+              batch_size,
+              spatial_size,
+              num_heads,
+              channels,
+              num_levels,
+              num_query,
+              num_point,
+              grad_value,
+              grad_sampling_loc,
+              grad_attn_weight);
+          }
+      }
+  }
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+  }
+
+}
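The switch above specializes the backward (col2im) reduction kernel on the channel count: power-of-two widths up to 1024 get a compile-time block size template argument (the `_v1` variants for small widths, `_v2` from 64 upward), while every other width falls through to the generic shared-memory kernels launched with `num_threads*3*sizeof(scalar_t)` bytes of dynamic shared memory — presumably three partials per thread, matching the three gradient outputs (sampling-location x, sampling-location y, and attention weight).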
detect_tools/upn/ops/src/ms_deform_attn.h
ADDED
@@ -0,0 +1,62 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+
+#include "cpu/ms_deform_attn_cpu.h"
+
+#ifdef WITH_CUDA
+#include "cuda/ms_deform_attn_cuda.h"
+#endif
+
+
+at::Tensor
+ms_deform_attn_forward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const int im2col_step)
+{
+    if (value.is_cuda())
+    {
+#ifdef WITH_CUDA
+        return ms_deform_attn_cuda_forward(
+            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
+#else
+        AT_ERROR("Not compiled with GPU support");
+#endif
+    }
+    AT_ERROR("Not implemented on the CPU");
+}
+
+std::vector<at::Tensor>
+ms_deform_attn_backward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const at::Tensor &grad_output,
+    const int im2col_step)
+{
+    if (value.is_cuda())
+    {
+#ifdef WITH_CUDA
+        return ms_deform_attn_cuda_backward(
+            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
+#else
+        AT_ERROR("Not compiled with GPU support");
+#endif
+    }
+    AT_ERROR("Not implemented on the CPU");
+}
+
detect_tools/upn/ops/src/vision.cpp
ADDED
@@ -0,0 +1,16 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include "ms_deform_attn.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
+  m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
+}
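Together, `ms_deform_attn.h` and `vision.cpp` are what make the CUDA kernels callable from Python. As a minimal sketch of how the two bound entry points are typically exposed to autograd (the `MSDeformAttnFunction` imported by `test.py` below plays this role; the extension module name `MultiScaleDeformableAttention` is an assumption based on the prebuilt wheel pinned in requirements.txt, not something this commit states):

from torch.autograd import Function
from torch.autograd.function import once_differentiable

import MultiScaleDeformableAttention as MSDA  # assumed name of the compiled extension


class MSDeformAttnFunctionSketch(Function):
    @staticmethod
    def forward(ctx, value, spatial_shapes, level_start_index,
                sampling_loc, attn_weight, im2col_step):
        ctx.im2col_step = im2col_step
        # Calls the pybind entry point defined in vision.cpp above.
        output = MSDA.ms_deform_attn_forward(
            value, spatial_shapes, level_start_index,
            sampling_loc, attn_weight, ctx.im2col_step)
        ctx.save_for_backward(value, spatial_shapes, level_start_index,
                              sampling_loc, attn_weight)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        value, spatial_shapes, level_start_index, sampling_loc, attn_weight = ctx.saved_tensors
        grad_value, grad_sampling_loc, grad_attn_weight = MSDA.ms_deform_attn_backward(
            value, spatial_shapes, level_start_index,
            sampling_loc, attn_weight, grad_output, ctx.im2col_step)
        # No gradients for the integer shape/index tensors or im2col_step.
        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None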
detect_tools/upn/ops/test.py
ADDED
@@ -0,0 +1,89 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import time
+import torch
+import torch.nn as nn
+from torch.autograd import gradcheck
+
+from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
+
+
+N, M, D = 1, 2, 2
+Lq, L, P = 2, 2, 2
+shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
+level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
+S = sum([(H*W).item() for H, W in shapes])
+
+
+torch.manual_seed(3)
+
+
+@torch.no_grad()
+def check_forward_equal_with_pytorch_double():
+    value = torch.rand(N, S, M, D).cuda() * 0.01
+    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+    im2col_step = 2
+    output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
+    output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
+    fwdok = torch.allclose(output_cuda, output_pytorch)
+    max_abs_err = (output_cuda - output_pytorch).abs().max()
+    max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
+
+    print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+
+
+@torch.no_grad()
+def check_forward_equal_with_pytorch_float():
+    value = torch.rand(N, S, M, D).cuda() * 0.01
+    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+    im2col_step = 2
+    output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
+    output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
+    fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
+    max_abs_err = (output_cuda - output_pytorch).abs().max()
+    max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
+
+    print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+
+
+def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
+
+    value = torch.rand(N, S, M, channels).cuda() * 0.01
+    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+    im2col_step = 2
+    func = MSDeformAttnFunction.apply
+
+    value.requires_grad = grad_value
+    sampling_locations.requires_grad = grad_sampling_loc
+    attention_weights.requires_grad = grad_attn_weight
+
+    gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
+
+    print(f'* {gradok} check_gradient_numerical(D={channels})')
+
+
+if __name__ == '__main__':
+    check_forward_equal_with_pytorch_double()
+    check_forward_equal_with_pytorch_float()
+
+    for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
+        check_gradient_numerical(channels, True, True, True)
+
+
+
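Note that `check_gradient_numerical` casts `value`, `sampling_locations`, and `attention_weights` to float64 before calling `gradcheck`: numerical gradient checking needs double precision to stay within tolerance. The same precision gap is why the float32 forward comparison above loosens its thresholds to `rtol=1e-2, atol=1e-3`, while the float64 comparison relies on the `torch.allclose` defaults.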
detect_tools/upn/requirments.txt
ADDED
@@ -0,0 +1 @@
+mmengine==0.8.2
detect_tools/upn/transforms/transform.py
ADDED
@@ -0,0 +1,142 @@
+import random
+import torch
+import torchvision.transforms.functional as F
+
+
+def resize(image, target, size, max_size=None):
+    # size can be min_size (scalar) or (w, h) tuple
+
+    def get_size_with_aspect_ratio(image_size, size, max_size=None):
+        w, h = image_size
+        if max_size is not None:
+            min_original_size = float(min((w, h)))
+            max_original_size = float(max((w, h)))
+            if max_original_size / min_original_size * size > max_size:
+                size = int(round(max_size * min_original_size / max_original_size))
+
+        if (w <= h and w == size) or (h <= w and h == size):
+            return (h, w)
+
+        if w < h:
+            ow = size
+            oh = int(size * h / w)
+        else:
+            oh = size
+            ow = int(size * w / h)
+
+        return (oh, ow)
+
+    def get_size(image_size, size, max_size=None):
+        if isinstance(size, (list, tuple)):
+            return size[::-1]
+        else:
+            return get_size_with_aspect_ratio(image_size, size, max_size)
+
+    size = get_size(image.size, size, max_size)
+    rescaled_image = F.resize(image, size)
+
+    if target is None:
+        return rescaled_image, None
+
+    ratios = tuple(
+        float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)
+    )
+    ratio_width, ratio_height = ratios
+
+    target = target.copy()
+    if "exampler_box" in target:
+        boxes = target["exampler_box"]
+        if isinstance(boxes, torch.Tensor):
+            scaled_boxes = boxes * torch.as_tensor(
+                [ratio_width, ratio_height, ratio_width, ratio_height]
+            )
+            target["exampler_box"] = scaled_boxes
+        elif isinstance(boxes, dict):
+            for k, v in boxes.items():
+                scaled_boxes = v * torch.as_tensor(
+                    [ratio_width, ratio_height, ratio_width, ratio_height]
+                )
+                target["exampler_box"][k] = scaled_boxes
+
+    if "demo_pos_exampler_box" in target:
+        boxes = target["demo_pos_exampler_box"]
+        scaled_boxes = boxes * torch.as_tensor(
+            [ratio_width, ratio_height, ratio_width, ratio_height]
+        )
+        target["demo_pos_exampler_box"] = scaled_boxes
+
+    if "demo_neg_exampler_box" in target:
+        boxes = target["demo_neg_exampler_box"]
+        scaled_boxes = boxes * torch.as_tensor(
+            [ratio_width, ratio_height, ratio_width, ratio_height]
+        )
+        target["demo_neg_exampler_box"] = scaled_boxes
+
+    if "boxes" in target:
+        boxes = target["boxes"]
+        scaled_boxes = boxes * torch.as_tensor(
+            [ratio_width, ratio_height, ratio_width, ratio_height]
+        )
+        target["boxes"] = scaled_boxes
+
+    if "area" in target:
+        area = target["area"]
+        scaled_area = area * (ratio_width * ratio_height)
+        target["area"] = scaled_area
+
+    h, w = size
+    target["size"] = torch.tensor([h, w])
+
+    return rescaled_image, target
+
+
+class RandomResize(object):
+
+    def __init__(self, sizes, max_size=None):
+        assert isinstance(sizes, (list, tuple))
+        self.sizes = sizes
+        self.max_size = max_size
+
+    def __call__(self, img, target=None):
+        size = random.choice(self.sizes)
+        return resize(img, target, size, self.max_size)
+
+
+class ToTensor(object):
+
+    def __call__(self, img, target):
+        return F.to_tensor(img), target
+
+
+class Normalize(object):
+
+    def __init__(self, mean, std):
+        self.mean = mean
+        self.std = std
+
+    def __call__(self, image, target=None):
+        image = F.normalize(image, mean=self.mean, std=self.std)
+        if target is None:
+            return image, None
+        target = target.copy()
+        h, w = image.shape[-2:]
+        return image, target
+
+
+class Compose(object):
+
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, image, target):
+        for t in self.transforms:
+            image, target = t(image, target)
+        return image, target
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + "("
+        for t in self.transforms:
+            format_string += "\n"
+            format_string += "    {0}".format(t)
+        format_string += "\n)"
+        return format_string
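A small usage sketch of the transforms above (the file path and normalization statistics are illustrative assumptions, not values taken from this commit): resize the shorter edge, convert to a tensor, then normalize.

from PIL import Image

from detect_tools.upn.transforms.transform import Compose, Normalize, RandomResize, ToTensor

# Typical DETR-style preprocessing pipeline built from the classes above.
transform = Compose([
    RandomResize([800], max_size=1333),                       # assumed sizes
    ToTensor(),
    Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # assumed ImageNet stats
])

image = Image.open("example.jpg").convert("RGB")  # hypothetical input
image_tensor, target = transform(image, None)     # pass a target dict to rescale its boxes too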
requirements.txt
ADDED
@@ -0,0 +1,19 @@
+torch==2.6.0
+torchvision==0.21.0
+transformers==4.50.1
+https://airesources.oss-cn-hangzhou.aliyuncs.com/lp/wheel/sam3-0.1.0-py3-none-any.whl
+timm==1.0.9
+accelerate==1.4.0
+gradio
+mmengine==0.8.2
+einops
+ninja
+scikit-image
+decord
+scikit-learn
+matplotlib
+modelscope
+https://airesources.oss-cn-hangzhou.aliyuncs.com/lp/wheel/multiscaledeformableattention-1.0-cp310-cp310-linux_x86_64.whl
+https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
+pycocotools
+opencv-python
resources/__init__.py
ADDED
File without changes
vlm_fo1/__init__.py
ADDED
@@ -0,0 +1 @@
+
vlm_fo1/constants.py
ADDED
@@ -0,0 +1,29 @@
+LOGDIR = "."
+
+global DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+# Model Constants
+IGNORE_INDEX = -100
+IMAGE_TOKEN_INDEX = -200  # alternatives: 151656, 151655
+DEFAULT_IMAGE_TOKEN = "<image>"
+DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+DEFAULT_IM_START_TOKEN = "<im_start>"
+DEFAULT_IM_END_TOKEN = "<im_end>"
+
+# For Qwen2_5_VL
+QWEN2_5_VL_IMAGE_TOKEN = "<|image_pad|>"
+QWEN2_5_VL_IMAGE_TOKEN_INDEX = 151655
+
+# For regions
+DEFAULT_REGION_TOKEN = "<region<i>>"
+DEFAULT_REGION_FEATURE_TOKEN = "<regionfeat>"
+DEFAULT_REGION_INDEX = -300  # alternative: 151654
+
+# For Grounding
+DEFAULT_GROUNDING_START = "<ground>"
+DEFAULT_GROUNDING_END = "</ground>"
+DEFAULT_GROUNDING_OBJECTS_START = "<objects>"
+DEFAULT_GROUNDING_OBJECTS_END = "</objects>"
+
+# For Think
+DEFAULT_THINK_START = "<think>"
+DEFAULT_THINK_END = "</think>"
vlm_fo1/mm_utils.py
ADDED
@@ -0,0 +1,660 @@
+from PIL import Image
+from PIL import ImageDraw, ImageOps
+from io import BytesIO
+import base64
+import re
+import torch
+from transformers import StoppingCriteria
+from vlm_fo1.constants import IMAGE_TOKEN_INDEX, DEFAULT_REGION_INDEX
+import requests
+from vlm_fo1.constants import (
+    IMAGE_TOKEN_INDEX,
+    DEFAULT_IMAGE_TOKEN,
+    DEFAULT_IM_START_TOKEN,
+    DEFAULT_IM_END_TOKEN,
+    IGNORE_INDEX,
+    DEFAULT_REGION_TOKEN,
+    DEFAULT_REGION_FEATURE_TOKEN
+)
+import torch
+from transformers import TextStreamer
+import random
+import re
+from typing import List, Tuple
+import io
+import base64
+
+
+def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
+    """
+    Tokenizes prompts containing <image> or <image_0>... special tokens.
+
+    If the prompt uses <image_0>, <image_1>, ..., each is replaced with a placeholder index (-200).
+    If the prompt uses <image>, it is replaced with image_token_index.
+
+    Args:
+        prompt (str): The prompt potentially containing image tokens.
+        tokenizer: The tokenizer object.
+        image_token_index (int): Token id to use when encountering <image> token.
+        return_tensors (Optional[str]): If 'pt', return a torch tensor.
+
+    Returns:
+        List[int] or torch.Tensor: The tokenized input with image token indices inserted appropriately.
+    """
+    if "<image_0>" in prompt:
+        # Case: prompt contains indexed image tokens like <image_0>, <image_1>, etc.
+        image_token_pattern = re.compile(r"<image_(\d+)>")
+        prompt_chunks = re.split(r'<image_[0-9]+>', prompt)
+        image_tags = image_token_pattern.findall(prompt)
+
+        input_ids = []
+        for i, chunk in enumerate(prompt_chunks):
+            input_ids.extend(tokenizer(chunk).input_ids)
+            if i < len(image_tags):
+                # Insert placeholder where the <image_n> token was.
+                input_ids.append(-200)
+    else:
+        # Case: prompt contains plain <image> tokens.
+        prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
+
+        def insert_separator(X, sep):
+            # Helper function to insert a separator token between chunks.
+            return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
+
+        input_ids = []
+        offset = 0
+        # If the first chunk starts with the <bos> token, make sure to keep it only once.
+        if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+            offset = 1
+            input_ids.append(prompt_chunks[0][0])
+
+        # Insert image_token_index between chunks.
+        for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+            input_ids.extend(x[offset:])
+    # Optionally convert output to a PyTorch tensor.
+    if return_tensors is not None:
+        if return_tensors == 'pt':
+            return torch.tensor(input_ids, dtype=torch.long)
+        else:
+            raise ValueError(f'Unsupported tensor type: {return_tensors}')
+
+    return input_ids
+
+def tokenizer_image_region_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, region_token_index=DEFAULT_REGION_INDEX, return_tensors=None):
+    """
+    Tokenizes prompts containing both <image> and <regionfeat> delimiters, inserting the specified token indices.
+
+    Each <image> chunk is split, and within that chunk, <regionfeat> locations receive region_token_index.
+
+    Args:
+        prompt (str): The prompt with <image> and <regionfeat> delimiters.
+        tokenizer: The tokenizer object.
+        image_token_index (int): Insert this at <image> splits.
+        region_token_index (int): Insert this at <regionfeat> splits.
+        return_tensors (Optional[str]): If 'pt', return a torch tensor.
+
+    Returns:
+        List[int] or torch.Tensor: The tokenized input with region/image tokens placed.
+    """
+    # Split by <image> tags first.
+    image_chunks = prompt.split('<image>')
+
+    prompt_chunks = []
+    for chunk in image_chunks:
+        # Split each image chunk by <regionfeat>.
+        obj_chunks = chunk.split('<regionfeat>')
+        # Tokenize each subchunk.
+        token_chunks = [tokenizer(c).input_ids for c in obj_chunks]
+        prompt_chunks.append(token_chunks)
+
+    input_ids = []
+    offset = 0
+
+    # If the first chunk starts with the <bos> token, include it only once.
+    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and len(prompt_chunks[0][0]) > 0 and prompt_chunks[0][0][0] == tokenizer.bos_token_id:
+        offset = 1
+        input_ids.append(prompt_chunks[0][0][0])
+
+    # Stitch together all chunks with region/image tokens at the appropriate locations.
+    for i, chunk_group in enumerate(prompt_chunks):
+        if len(chunk_group) > 0:
+            input_ids.extend(chunk_group[0][offset:])
+            for chunk in chunk_group[1:]:
+                input_ids.append(region_token_index)
+                input_ids.extend(chunk)
+        # Insert the <image> token except after the last image chunk.
+        if i < len(prompt_chunks) - 1:
+            input_ids.append(image_token_index)
+    # Optionally convert to a PyTorch tensor.
+    if return_tensors is not None:
+        if return_tensors == 'pt':
+            return torch.tensor(input_ids, dtype=torch.long)
+        else:
+            raise ValueError(f'Unsupported tensor type: {return_tensors}')
+
+    return input_ids
+
+class KeywordsStoppingCriteria(StoppingCriteria):
+    """
+    Implements custom stopping criteria for generation based on keywords:
+    if the generated output contains any of the keywords, generation stops.
+    """
+    def __init__(self, keywords, tokenizer, input_ids):
+        self.keywords = keywords
+        self.keyword_ids = []
+        self.max_keyword_len = 0
+        for keyword in keywords:
+            cur_keyword_ids = tokenizer(keyword).input_ids
+            # Remove BOS if present, except for a single token
+            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
+                cur_keyword_ids = cur_keyword_ids[1:]
+            if len(cur_keyword_ids) > self.max_keyword_len:
+                self.max_keyword_len = len(cur_keyword_ids)
+            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
+        self.tokenizer = tokenizer
+        # Track the generation start length
+        self.start_len = input_ids.shape[1]
+
+    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        """
+        Checks if a keyword exists in the latest generated output ids for a single batch element.
+        """
+        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
+        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
+        for keyword_id in self.keyword_ids:
+            truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
+            if torch.equal(truncated_output_ids, keyword_id):
+                return True
+        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
+        for keyword in self.keywords:
+            if keyword in outputs:
+                return True
+        return False
+
+    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        """
+        Checks for keywords in each batch item; stops when all have satisfied the keyword condition.
+        """
+        outputs = []
+        for i in range(output_ids.shape[0]):
+            outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
+        return all(outputs)
+
+def load_image(image_file):
+    """
+    Loads an image from a local path, base64 string, URL, or PIL.Image.
+
+    If the input image is smaller than 28x28, it will be resized to at least that size.
+
+    Args:
+        image_file (str or PIL.Image.Image): Image source.
+
+    Returns:
+        PIL.Image.Image: Loaded image in RGB mode, at least 28x28 in size.
+    """
+    if isinstance(image_file, Image.Image):
+        image = image_file.convert("RGB")
+    # Case: load from URL
+    elif image_file.startswith("http") or image_file.startswith("https"):
+        response = requests.get(image_file)
+        image = Image.open(BytesIO(response.content)).convert("RGB")
+    # Case: load from base64-encoded string
+    elif image_file.startswith("data:image/"):
+        image = image_file.replace("data:image/jpeg;base64,", "")
+        image_data = base64.b64decode(image)
+        image = Image.open(BytesIO(image_data)).convert("RGB")
+    elif isinstance(image_file, str):
+        # Case: load from local file path
+        image = Image.open(image_file).convert("RGB")
+    else:
+        raise ValueError(f"Unsupported image type: {type(image_file)}")
+
+    # Ensure minimum size 28x28
+    if image.width < 28 or image.height < 28:
+        image = image.resize((max(28, image.width), max(28, image.height)))
+    return image
+
+def image_to_base64(img_pil):
+    """
+    Encodes a PIL Image as JPEG in base64 format.
+
+    Args:
+        img_pil (PIL.Image.Image): Source image.
+
+    Returns:
+        str: base64-encoded JPEG image string.
+    """
+    with io.BytesIO() as buffer:
+        img_pil.save(buffer, format="JPEG")
+        base64_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
+    return base64_image
+
+def draw_bboxes_and_save(
+    image: Image.Image,
+    fo1_bboxes: dict = {},
+    detection_bboxes: List[Tuple[int, int, int, int]] = [],
+    output_path: str = 'output.jpg',
+    color: str = 'red',
+    total_color: str = 'green',
+    width: int = 2
+) -> None:
+    """
+    Draws bounding boxes (both ground-truth/proposed and detection) on a PIL image and saves the result.
+
+    Args:
+        image (PIL.Image.Image): Input PIL image object.
+        fo1_bboxes (dict): Label -> List[bbox] mapping for annotation bboxes.
+        detection_bboxes (List[Tuple]): List of detection bounding boxes; each bbox is (x_min, y_min, x_max, y_max).
+        output_path (str): Path to save the output image.
+        color (str): Color for fo1_bboxes.
+        total_color (str): Color for detection_bboxes.
+        width (int): Rectangle outline width.
+
+    Returns:
+        None
+    """
+    draw = ImageDraw.Draw(image)
+
+    # Draw detection boxes with `total_color`
+    for bbox in detection_bboxes:
+        if len(bbox) != 4:
+            print(f"Warning: skip the invalid bbox {bbox}")
+            continue
+        shape = [(bbox[0], bbox[1]), (bbox[2], bbox[3])]
+        draw.rectangle(shape, outline=total_color, width=width)
+
+    # Draw annotated bboxes with labels and `color`
+    for bbox_label, bbox_list in fo1_bboxes.items():
+        for bbox in bbox_list:
+            if len(bbox) != 4:
+                print(f"Warning: skip the invalid bbox {bbox}")
+                continue
+            shape = [(bbox[0], bbox[1]), (bbox[2], bbox[3])]
+            draw.rectangle(shape, outline=color, width=width)
+            draw.text((bbox[0], bbox[1]), bbox_label, fill=color)
+
+    # Save output image (catching common IO exceptions).
+    try:
+        image.save(output_path)
+        print(f"The image has been successfully saved to: {output_path}")
+    except IOError as e:
+        print(f"Error: failed to save the image to {output_path}. Reason: {e}")
+
+def adjust_bbox(bbox_list, original_h, original_w, resize_h, resize_w):
+    """
+    Adjusts bounding boxes from the original image size to the resized image size, compensating for scaling.
+
+    Args:
+        bbox_list (List[List[float]]): List of original boxes [x1, y1, x2, y2].
+        original_h (int): Original image height.
+        original_w (int): Original image width.
+        resize_h (int): Resized image height.
+        resize_w (int): Resized image width.
+
+    Returns:
+        List[List[float]]: Bounding boxes transformed to resized image coordinates.
+    """
+    output_list = []
+    def adjust_bbox_range(bbox, width, height):
+        # Ensure all coordinates are within the original image border.
+        x1, y1, x2, y2 = bbox
+        x1 = max(0, min(width, x1))
+        y1 = max(0, min(height, y1))
+        x2 = max(0, min(width, x2))
+        y2 = max(0, min(height, y2))
+        return [x1, y1, x2, y2]
+
+    for bbox in bbox_list:
+        bbox = adjust_bbox_range(bbox, original_w, original_h)
+        bbox[0] = bbox[0] * resize_w / original_w
+        bbox[1] = bbox[1] * resize_h / original_h
+        bbox[2] = bbox[2] * resize_w / original_w
+        bbox[3] = bbox[3] * resize_h / original_h
+        output_list.append(bbox)
+    return output_list
+
+def extract_predictions_to_bboxes(prediction: str, bbox_list):
+    """
+    Parse a prediction string in the expected format and map each ground label
+    to its corresponding bounding boxes using bbox_list.
+
+    Args:
+        prediction (str): Model output string with <ground>...<objects>... markup.
+        bbox_list (List[List[float]]): Full list of predicted or reference bounding boxes.
+
+    Returns:
+        dict: label -> list of bboxes
+    """
+    label_to_indexes = {}
+    label_to_bboxes = {}
+
+    match_pattern = r"<ground>(.*?)<\/ground><objects>(.*?)<\/objects>"
+    matches = re.findall(match_pattern, prediction)
+
+    for label_text, indexes in matches:
+        label_text = label_text.strip()
+        indexes_tags = re.findall(r"<region\d+>", indexes)
+        region_indexes = set([int(index.split("<region")[-1].split(">")[0]) for index in indexes_tags])
+        if label_text not in label_to_indexes:
+            label_to_indexes[label_text] = region_indexes
+        else:
+            label_to_indexes[label_text] = label_to_indexes[label_text] | region_indexes
+
+    for label, indexes in label_to_indexes.items():
+        label_to_bboxes[label] = [bbox_list[index] for index in indexes]
+
+    return label_to_bboxes
+
+def extract_predictions_to_indexes(prediction: str):
+    """
+    Parse a prediction string, returning a label -> set-of-indexes mapping.
+
+    Args:
+        prediction (str): Model prediction output.
+
+    Returns:
+        dict: label -> set(int)
+    """
+    label_to_indexes = {}
+    match_pattern = r"<ground>(.*?)<\/ground><objects>(.*?)<\/objects>"
+    matches = re.findall(match_pattern, prediction)
+
+    for label_text, indexes in matches:
+        label_text = label_text.strip()
+        indexes_tags = re.findall(r"<region\d+>", indexes)
+        region_indexes = set([int(index.split("<region")[-1].split(">")[0]) for index in indexes_tags])
+        if label_text not in label_to_indexes:
+            label_to_indexes[label_text] = region_indexes
+        else:
+            label_to_indexes[label_text] = label_to_indexes[label_text] | region_indexes
+
+    return label_to_indexes
+
+def resize_shortest_edge_images_and_bboxes(
+    image_list: List[Image.Image],
+    bbox_lists: List,
+    candidate_sizes: List[int] = [],
+    max_size: int = 2048
+):
+    """
+    Randomly selects a size for the shortest edge, and proportionally resizes both images and bounding boxes.
+
+    The function maintains the image aspect ratio and ensures that the resized dimensions do not exceed the specified max_size.
+    Bounding boxes are transformed accordingly.
+
+    Args:
+        image_list (List[Image.Image]): A list of PIL Image objects.
+        bbox_lists (List[List[List[float]]]): A list of lists of bounding boxes per image.
+        candidate_sizes (List[int]): Optional list of sizes to choose the target short edge from.
+        max_size (int): Maximum allowed long edge after resizing.
+
+    Returns:
+        Tuple[List[Image.Image], List[List[List[float]]]]:
+            ([resized_image1, ...], [bbox_list1, ...]); the output shape mirrors the input (see the unpacking at the end).
+
+    Raises:
+        ValueError: on input list length mismatch or emptiness.
+    """
+    bbox_tensor = torch.tensor(bbox_lists)
+    # Normalize input: wrap bbox_lists into a list-of-lists, if needed.
+    if len(bbox_tensor.shape) == 2 and bbox_tensor.shape[1] == 4:
+        bbox_lists = [bbox_lists]
+
+    if not image_list or not bbox_lists:
+        raise ValueError("Input lists cannot be empty.")
+    if len(image_list) != len(bbox_lists):
+        raise ValueError("The lengths of the image list and the bounding box list must be the same.")
+
+    # Randomly select a short edge size (if given candidate sizes)
+    if len(candidate_sizes) > 0:
+        target_size = random.choice(candidate_sizes)
+    else:
+        target_size = None
+
+    resized_images = []
+    transformed_bbox_lists = []
+
+    # Process each image and its corresponding bbox list
+    for img, bboxes in zip(image_list, bbox_lists):
+        original_width, original_height = img.size
+
+        # Determine the scaling factor to bring the short edge to target_size
+        shortest_side = min(original_width, original_height)
+        if target_size:
+            scale = target_size / shortest_side
+        else:
+            scale = 1.0
+
+        # Propose new height and width with this scale
+        new_height, new_width = int(original_height * scale), int(original_width * scale)
+
+        # If the resulting long edge exceeds max_size, rescale down so that it fits.
+        longest_side = max(new_height, new_width)
+        if longest_side > max_size:
+            scale = max_size / longest_side
+            new_height, new_width = int(new_height * scale), int(new_width * scale)
+        # Ensure images are at least 28x28 (the model may expect it)
+        new_width = max(28, new_width)
+        new_height = max(28, new_height)
+
+        # Resize the image, using BICUBIC for quality if the shape changes
+        if new_width == original_width and new_height == original_height:
+            resized_img = img
+        else:
+            resized_img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
+        resized_images.append(resized_img)
+
+        # Transform bounding boxes
+        current_transformed_bboxes = []
+        scale_ratio_x = new_width / original_width
+        scale_ratio_y = new_height / original_height
+        for bbox in bboxes:
+            x1, y1, x2, y2 = bbox
+            new_x1 = x1 * scale_ratio_x
+            new_y1 = y1 * scale_ratio_y
+            new_x2 = x2 * scale_ratio_x
+            new_y2 = y2 * scale_ratio_y
+            current_transformed_bboxes.append([new_x1, new_y1, new_x2, new_y2])
+        transformed_bbox_lists.append(current_transformed_bboxes)
+
+    # If the original input was a single bbox list (not nested), unpack.
+    if len(bbox_tensor.shape) == 2 and bbox_tensor.shape[1] == 4:
+        return resized_images, transformed_bbox_lists[0]
+    else:
+        return resized_images, transformed_bbox_lists
+
+def make_message_context(tokenizer, message, chat_format="chatml"):
+    """
+    Given a message dict, construct the prompt, tokenized context tokens, image URLs, and bbox_list.
+
+    Handles both standard string 'content' and multi-part (list) content, appropriately placing image/region tokens.
+
+    Args:
+        tokenizer: tokenizer object
+        message (dict): Contains role, content, and optionally bbox_list.
+        chat_format (str): Optionally select chat format (default 'chatml').
+
+    Returns:
+        tuple: (inp, context_tokens, image_urls, bbox_list)
+    """
+    image_urls = []
+    if chat_format == "chatml":
+        im_start, im_end = "<|im_start|>", "<|im_end|>"
+        im_start_tokens = [151644]
+        im_end_tokens = [151645]
+        nl_tokens = tokenizer.encode("\n")
+        role = message["role"]
+        content = message["content"]
+        bbox_list = message.get("bbox_list", None)
+
+        if role == "system":
+            inp = f"{im_start}{role}\n{content}{im_end}\n"
+            context_tokens = tokenizer.encode(
+                role, allowed_special=set()) + nl_tokens + tokenizer.encode(content, allowed_special=set())
+            context_tokens = im_start_tokens + context_tokens + im_end_tokens
+
+        if role == "user":
+            if isinstance(content, str):
+                # Plain string message
+                inp = f"{im_start}{role}\n{content}{im_end}\n"
+                context_tokens = tokenizer.encode(
+                    role, allowed_special=set()) + nl_tokens + tokenizer.encode(content,
+                                                                                allowed_special=set())
+                context_tokens = im_start_tokens + context_tokens + im_end_tokens
+            if isinstance(content, list):
+                # Multi-part message (text and image_url parts, maybe region tokens)
+                inp = f"{im_start}{role}\n"
+                image_count = 1
+                for message_part in content:
+                    if message_part["type"] == "text":
+                        inp += f"{message_part['text']}"
+
+                    if message_part["type"] == "image_url":
+                        # Insert special vision/image tokens, possibly region tokens
+                        inp += DEFAULT_IM_START_TOKEN + '<image>' + DEFAULT_IM_END_TOKEN + '\n'
+                        # If regions exist, add a per-region special token.
+                        if bbox_list and len(bbox_list) > 0:
+                            for idx, bbox in enumerate(bbox_list):
+                                inp += DEFAULT_REGION_TOKEN.replace('<i>', str(idx)) + DEFAULT_REGION_FEATURE_TOKEN
+                            inp += '\n'
+
+                        image_urls.append(message_part['image_url']['url'])
+                        image_count += 1
+                inp += f"{im_end}\n"
+
+                # Choose tokenizer logic based on whether a bbox (region) list exists
+                if bbox_list and len(bbox_list) > 0:
+                    context_tokens = tokenizer_image_region_token(inp, tokenizer)
+                else:
+                    context_tokens = tokenizer_image_token(inp, tokenizer, image_token_index=IMAGE_TOKEN_INDEX)
+    return inp, context_tokens, image_urls, bbox_list
+
+def prepare_inputs(model_name, model, image_processors, tokenizer, messages, device="cuda", max_tokens=512, top_p=1.0, temperature=0.0, do_sample=False, image_size=None):
+    """
+    Fully prepares keyword arguments for model.generate (and compatible APIs) from messages and model specs.
+
+    Handles prompt assembly, tokenization, image loading/preprocessing, region support, streaming, etc.
+    Supports a specific tweak for Qwen2.5-VL style vision tokens.
+
+    Args:
+        model_name (str): Model identifier string.
+        model: Model/config object.
+        image_processors (tuple): (primary, auxiliary) image processors.
+        tokenizer: Tokenizer object.
+        messages (list): Multi-message input list (chat history).
+        device (str): Target (usually 'cuda' or 'cpu').
+        max_tokens, top_p, temperature, do_sample: Standard generation kwargs.
+
+    Returns:
+        dict: ready-to-use argument dict for model.generate().
+    """
+    # For Qwen2.5-VL, patch the vision special tokens globally.
+    if 'qwen2.5-vl' in model_name.lower() or 'qwen2_5_vl' in model_name.lower():
+        global DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+        DEFAULT_IM_START_TOKEN = "<|vision_start|>"
+        DEFAULT_IM_END_TOKEN = "<|vision_end|>"
+
+    primary_image_processor, auxiliary_image_processor = image_processors
+
+    prompt = ""
+    input_tokens = []
+    image_urls = []
+    # Compose the prompt and accumulate all components from the provided messages
+    for message in messages:
+        inp, context_tokens, image_urls, bbox_list = make_message_context(tokenizer, message)
+        prompt += inp
+        input_tokens.extend(context_tokens)
+
+    # Ensure a system prompt at the start, if not already present.
+    if "system" not in prompt:
+        system_content = "system\nYou are a helpful assistant."
+        system_prompt = "<|im_start|>" + system_content + "<|im_end|>" + "\n"
+        prompt = system_prompt + prompt
+        system_tokens = [151644] + tokenizer(system_content).input_ids + [151645] + tokenizer("\n").input_ids
+        input_tokens = system_tokens + input_tokens
+
+    # Ensure the prompt ends with the assistant's turn.
+    if not prompt.endswith("<|im_start|>assistant"):
+        last_assistant_prompt = "<|im_start|>" + "assistant" + "\n"
+        prompt += last_assistant_prompt
+        # last_assistant_tokens = [6] + self.tokenizer("assistant\n").input_ids
+        last_assistant_tokens = [151644] + tokenizer("assistant\n").input_ids
+        input_tokens.extend(last_assistant_tokens)
+
+    primary_images_tensor = None
+    auxiliary_images_tensor = None
+    primary_image_grid_thws = None
+    if image_urls:
+        # Load images, resize them, and update bbox_list downstream
+        images = [load_image(i) for i in image_urls]
+        if image_size is not None:
+            images, bbox_list = resize_shortest_edge_images_and_bboxes(images, bbox_list, candidate_sizes=[image_size], max_size=2048)
+        else:
+            images, bbox_list = resize_shortest_edge_images_and_bboxes(images, bbox_list, max_size=2048)
+
+        # When region-indexed tokens are enabled
+        if getattr(model.config, 'mm_use_region_index_token', False):
+            origin_image_size = [image.size for image in images]
+            aux_images = images.copy()
+            auxiliary_images_tensor = [auxiliary_image_processor.preprocess(i, return_tensors='pt')['pixel_values'][0].to(device) for i in aux_images]
+
+            if bbox_list and len(bbox_list) > 0:
+                # Limit the number of bboxes (for computational constraints, etc.)
+                bbox_list = bbox_list[:100]
+                resize_h, resize_w = auxiliary_images_tensor[0].shape[-2:]
+                original_w, original_h = origin_image_size[0]
+                # Adjust bboxes to match the resized images (post pre-processing)
+                bbox_list = adjust_bbox(bbox_list, original_h, original_w, resize_h, resize_w)
+                bbox_list = [torch.tensor(bbox_list)]
+            else:
+                bbox_list = None
+        else:
+            auxiliary_images_tensor = None
+
+        # Preprocess primary images for the main vision model branch
+        primary_images = []
+        primary_image_grid_thws = []
+        for im in images:
+            processed_data = primary_image_processor.preprocess(im, return_tensors="pt")
+            image_i = processed_data['pixel_values']
+            image_grid_thw_i = processed_data['image_grid_thw']
+            primary_images.append(image_i)
+            primary_image_grid_thws.append(image_grid_thw_i)
+        primary_images_tensor = [image_i.to(device) for image_i in primary_images]
+
+    # For Qwen-style models, force the specific end token as the stopping criterion
+    if "qwen" in model_name.lower():
+        input_ids = torch.tensor([input_tokens]).to(device)
+        keywords = ["<|im_end|>"]
+
+    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
+    streamer = TextStreamer(
+        tokenizer, skip_prompt=True, skip_special_tokens=True
+    )
+
+    # Default: greedy decoding if temperature=0. Else: enable sampling.
+    if temperature == 0.0:
+        do_sample = False
+    else:
+        do_sample = True
+
+    print("question:================\n", prompt, "\n=================")
+    # print("input ids:========", input_ids, "========")
+    generation_kwargs = dict(
+        inputs=input_ids,
+        images=primary_images_tensor,
+        images_aux=auxiliary_images_tensor,
+        image_grid_thws=primary_image_grid_thws,
+        bbox_list=bbox_list,
+        do_sample=do_sample,
+        temperature=temperature,
+        max_new_tokens=max_tokens,
+        streamer=streamer,
+        top_p=top_p,
+        use_cache=True,
+        stopping_criteria=[stopping_criteria],
+        pad_token_id=tokenizer.pad_token_id
+    )
+    return generation_kwargs
+
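A minimal sketch of how `tokenizer_image_token` behaves (the checkpoint name and prompt here are hypothetical): each `<image>` marker becomes the placeholder id `IMAGE_TOKEN_INDEX` (-200) in the returned sequence, which the model later swaps for visual features.

from transformers import AutoTokenizer

from vlm_fo1.constants import IMAGE_TOKEN_INDEX
from vlm_fo1.mm_utils import tokenizer_image_token

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")  # assumed checkpoint
prompt = "<|im_start|>user\n<image>\nDescribe this picture.<|im_end|>\n"
input_ids = tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors="pt")
assert (input_ids == IMAGE_TOKEN_INDEX).sum() == 1  # one placeholder per <image> marker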
vlm_fo1/model/__init__.py
ADDED
@@ -0,0 +1 @@
+from .language_model.omchat_qwen2_5_vl import OmChatQwen25VLForCausalLM, OmChatQwen25VLConfig
vlm_fo1/model/builder.py
ADDED
@@ -0,0 +1,89 @@
from transformers import AutoTokenizer
import torch
from vlm_fo1.model import *
from safetensors.torch import load_file
import os


def load_pretrained_model(model_path, load_8bit=False, load_4bit=False, device="cuda"):
    """
    Loads a pretrained model along with its vision towers (and associated image processors).
    This function supports loading in 8-bit/4-bit precision and explicit device placement.

    Args:
        model_path (str): Path to the pretrained model directory.
        load_8bit (bool): Whether to load the model in 8-bit mode.
        load_4bit (bool): Whether to load the model in 4-bit mode.
        device (str): Device to load the model onto, e.g., "cuda" or "cpu".

    Returns:
        tuple: (tokenizer, model, image_processor)
    """
    kwargs = {"device_map": device}

    # Set model loading parameters for quantization or floating point
    if load_8bit:
        kwargs['load_in_8bit'] = True
    elif load_4bit:
        kwargs['load_in_4bit'] = True
    else:
        kwargs['torch_dtype'] = torch.bfloat16

    # print(model_path)

    # Only proceed for vlm-fo1 models
    if 'vlm-fo1' in model_path.lower():
        # Load the tokenizer (slow tokenizer enforced)
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
        # If this is the Qwen2.5-VL variant, load with additional kwargs
        if 'qwen2.5-vl' in model_path.lower() or 'qwen2_5_vl' in model_path.lower():
            model, loading_info = OmChatQwen25VLForCausalLM.from_pretrained(
                model_path,
                low_cpu_mem_usage=True,
                output_loading_info=True,
                attn_implementation="flash_attention_2",
                **kwargs
            )
            # print(f'OmChatQwen25VLForCausalLM loading_info: {loading_info}')
        # (For other vlm-fo1 variants, model loading may need an additional condition.)

    if 'vlm-fo1' in model_path.lower():
        # --- Vision Tower Loading ---
        # Load the main vision tower weights from model_path if they are not yet loaded
        primary_vision_tower = model.get_vision_tower()
        if primary_vision_tower and not primary_vision_tower.is_loaded:
            primary_vision_tower.load_model(model_path=model_path, is_train=False)
            primary_vision_tower.to(device=device, dtype=torch.bfloat16)  # Move to the correct device/dtype

        # Grab the primary image processor from the vision tower, if present
        if primary_vision_tower:
            primary_image_processor = primary_vision_tower.image_processor

        # --- Auxiliary Vision Tower Handling (Qwen2.5-VL case only) ---
        if 'qwen2.5-vl' in model_path.lower() or 'qwen2_5_vl' in model_path.lower():
            try:
                aux_image_size = model.config.aux_image_size
            except Exception:
                # If aux_image_size is missing from the config, fall back to 768
                aux_image_size = 768

            aux_image_aspect_ratio = model.config.aux_image_aspect_ratio
            aux_vision_tower = model.get_vision_tower_aux()
            # Only load if not already loaded
            if aux_vision_tower and not aux_vision_tower.is_loaded:
                aux_vision_tower.load_model(image_size=aux_image_size, is_train=False, aspect_ratio=aux_image_aspect_ratio)
                aux_vision_tower.to(device=device, dtype=torch.bfloat16)

            # Get the auxiliary image processor if there is an aux vision tower
            if aux_vision_tower:
                aux_image_processor = aux_vision_tower.image_processor
            else:
                aux_image_processor = None  # No auxiliary vision tower available

        # image_processor is returned as a tuple of (primary, aux)
        image_processor = (primary_image_processor, aux_image_processor)

    # Set the model to eval mode and move it to the correct device before returning
    model.eval()
    model.to(device=device, dtype=torch.bfloat16)
    return tokenizer, model, image_processor
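For reference, a typical call site for this loader might look as follows (a sketch: the checkpoint path is a placeholder; the tuple unpacking mirrors the return value above):

    tokenizer, model, image_processor = load_pretrained_model(
        "checkpoints/vlm-fo1-qwen2.5-vl",  # placeholder path
        device="cuda",
    )
    primary_image_processor, aux_image_processor = image_processor  # (primary, aux)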
vlm_fo1/model/language_model/omchat_qwen2_5_vl.py
ADDED
@@ -0,0 +1,576 @@
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from transformers import Qwen2_5_VLConfig, AutoConfig, AutoModelForCausalLM
from vlm_fo1.model.multimodal_encoder.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel, Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLCausalLMOutputWithPast
from vlm_fo1.model.multimodal_encoder.qwen2_5_vl_encoder import Qwen2_5_VlVisionTower
from vlm_fo1.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_REGION_INDEX, QWEN2_5_VL_IMAGE_TOKEN, QWEN2_5_VL_IMAGE_TOKEN_INDEX

from ..omchat_arch import OmChatMetaModel, OmChatMetaForCausalLM

# Custom config which extends Qwen2_5_VLConfig for the OmChat multimodal model
class OmChatQwen25VLConfig(Qwen2_5_VLConfig):
    model_type = "omchat_qwen2_5_vl"
    rotary_type = "normal_rotary"
    multi_scale_im = None
    vision_tower_aux = None

# Core model definition: inherits from the OmChat meta model and the Qwen multimodal base
class OmChatQwen25VLModel(OmChatMetaModel, Qwen2_5_VLModel):
    config_class = OmChatQwen25VLConfig

    def __init__(self, config: Qwen2_5_VLConfig):
        super(OmChatQwen25VLModel, self).__init__(config)

# Main class for the multimodal causal LM
class OmChatQwen25VLForCausalLM(Qwen2_5_VLForConditionalGeneration, OmChatMetaForCausalLM):
    config_class = OmChatQwen25VLConfig

    def __init__(self, config, delay_load=True):
        # Ensure the config has a delay_load property
        if not hasattr(config, 'delay_load'):
            config.delay_load = delay_load
        super(Qwen2_5_VLForConditionalGeneration, self).__init__(config)
        self.model = OmChatQwen25VLModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.rope_deltas = None  # cache rope_deltas here

        self.post_init()

    # Encode input images into feature representations
    def encode_images(self, images, images_grid_thw=None):
        # If the vision tower is Qwen2.5-specific, use its custom forward signature
        if isinstance(self.get_model().get_vision_tower(), Qwen2_5_VlVisionTower):
            image_features = self.get_model().get_vision_tower()(images, images_grid_thw)
            image_features, image_grid_thws, multi_level_features = image_features
            # If there are multiple images, handle concatenation
            if type(image_features) is list:
                # The list has items of shape (1, seq_len, dim)
                token_length_list = [i.shape[1] for i in image_features]
                image_features = torch.cat(image_features, dim=1)  # Concatenate to (1, total_seq_len, dim)
        else:
            image_features = self.get_model().get_vision_tower()(images)
            image_grid_thws = None
            multi_level_features = None

        image_features = self.get_model().mm_projector(image_features)

        # Split concatenated image features back by their original lengths (for the multi-image case)
        if isinstance(self.get_model().get_vision_tower(), Qwen2_5_VlVisionTower):
            start = 0
            new_image_features = []
            # Split according to token_length_list
            for length in token_length_list:
                end = start + length
                new_image_features.append(image_features[:, start:end, :].squeeze(0))
                start = end
            image_features = new_image_features

        return image_features, image_grid_thws, multi_level_features

    # Encode image regions (bounding boxes) into features, optionally using the auxiliary vision tower
    def encode_regions(self, images, bbox_list, vt_multi_level_features=None, vt_images_size=None):
        aux_image_features_list = self.get_model().get_vision_tower_aux()(images)
        region_features = []
        if getattr(self.config, "mm_use_vision_tower_region_feature", False):
            image_features_list = vt_multi_level_features
            for batch_idx, (image_features, aux_image_features) in enumerate(zip(image_features_list, aux_image_features_list)):

                if getattr(self.config, "mm_use_simpleFPN_for_vt", False):
                    multilevel_visual_feats = image_features[-1]
                else:
                    multilevel_visual_feats = image_features
                multilevel_aux_visual_feats = aux_image_features["image_features"]
                boxes = bbox_list[batch_idx]

                # If no boxes are provided, use a dummy box (covers a tiny region)
                if boxes is None or len(boxes) == 0:
                    boxes = torch.tensor([[0, 10, 0, 10]], device=multilevel_aux_visual_feats[0].device, dtype=torch.float32)

                boxes = boxes.to(torch.float32).to(multilevel_aux_visual_feats[0].device)
                current_image_height, current_image_width = images[batch_idx].shape[-2:]
                original_height, original_width = vt_images_size[batch_idx]
                # Scale bounding boxes from the original image size to the processed size
                scale_height = original_height / current_image_height
                scale_width = original_width / current_image_width
                vt_boxes = boxes * torch.tensor([scale_width, scale_height, scale_width, scale_height], device=boxes.device)

                extracted_region_feat = self.get_model().object_vp_extractor(
                    aux_multi_level_features=multilevel_aux_visual_feats,
                    vt_multi_level_features=multilevel_visual_feats,
                    aux_boxes=[boxes],
                    vt_boxes=[vt_boxes]
                ).squeeze(0).to(multilevel_aux_visual_feats[0].dtype)
                region_feat = self.get_model().mm_projector_aux(extracted_region_feat)  # [num_bbox, 2048]
                region_features.append(region_feat)
        else:
            # Extract region features only from the auxiliary vision tower
            for batch_idx, image_features in enumerate(aux_image_features_list):
                multilevel_visual_feats = image_features["image_features"]
                last_feat = image_features["last_feat"]
                boxes = bbox_list[batch_idx]

                if boxes is None or len(boxes) == 0:
                    boxes = torch.tensor([[0, 10, 0, 10]], device=multilevel_visual_feats[0].device, dtype=torch.float32)

                multi_level_aux_features = multilevel_visual_feats
                boxes = boxes.to(torch.float32).to(multi_level_aux_features[0].device)
                extracted_region_feat = self.get_model().object_vp_extractor(
                    multi_level_aux_features,
                    [boxes],
                ).squeeze(0).to(multi_level_aux_features[0].dtype)
                region_feat = self.get_model().mm_projector_aux(extracted_region_feat)  # [num_bbox, 2880]
                region_features.append(region_feat)

        return region_features

    def get_model(self):
        # Getter for the model. Used to access backbone/model internals.
        return self.model

    # Convert sequences of input_ids/labels/images/boxes into multimodal embeddings and the associated masks/ids for transformer input.
    def prepare_inputs_labels_for_qwen2_5_vl_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels, images, images_aux=None, bbox_list=None, image_grid_thws=None
    ):
        # ========================== Input parsing and batching =============================
        vision_tower = self.get_vision_tower()
        video_tower = self.get_video_tower()
        vision_tower_aux = self.get_vision_tower_aux()
        # Fast path for the non-multimodal case, or for cached decoding steps where the input contains a single new token
        if (vision_tower is None and video_tower is None) or images is None or input_ids.shape[1] == 1:
            cache_position = None  # default when not resuming from a KV cache
            if past_key_values is not None and (vision_tower is not None or video_tower is not None) and images is not None and input_ids.shape[1] == 1:

                target_shape = past_key_values[-1][-1].shape[-2] + 1
                attention_mask = torch.cat((attention_mask, torch.ones(
                    (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device
                )), dim=1)

                position_ids = None
                cache_position = torch.tensor([target_shape - 1], device=attention_mask.device)
            return input_ids, position_ids, attention_mask, past_key_values, None, labels, None, cache_position

        # Indices for images (3D or 2D tensors) and videos (4D tensors)
        image_idx = [idx for idx, img in enumerate(images) if img.ndim == 3 or img.ndim == 2]
        is_all_image = len(image_idx) == len(images)
        video_idx = [idx for idx, vid in enumerate(images) if vid.ndim == 4]

        # Stack image and video tensors accordingly for mini-batch processing
        if isinstance(vision_tower, Qwen2_5_VlVisionTower):
            images_minibatch = [images[idx] for idx in image_idx] if len(image_idx) > 0 else []  # list of [c,h,w], can have variable shapes
        else:
            images_minibatch = torch.stack([images[idx] for idx in image_idx]) if len(image_idx) > 0 else []  # tensor [mini_b, c, h, w]
        videos_minibatch = torch.stack([images[idx] for idx in video_idx]) if len(video_idx) > 0 else []  # tensor [mini_b, c, t, h, w]

        # Auxiliary batch for region encoding, if relevant
        if vision_tower_aux is not None and images_aux is not None:
            images_minibatch_aux = [images_aux[idx].unsqueeze(0) for idx in image_idx] if len(image_idx) > 0 else []  # list of [1, c, h, w]

        # tmp_image_features will be indexed to scatter extracted image/video features into their original batch positions
        tmp_image_features = [None] * (len(image_idx) + len(video_idx))
        if getattr(images_minibatch, 'ndim', 0) == 4 or (type(images_minibatch) is list and len(images_minibatch) > 0):  # batch consists of images, [mini_b, c, h, w]
            if vision_tower is not None:
                image_features_minibatch, image_grid_thws_minibatch, vt_multi_level_features_minibatch = self.encode_images(images_minibatch, image_grid_thws)  # [mini_b, l, c]
            else:
                image_features_minibatch = torch.randn(1).to(self.device)  # dummy feature for video-only training under tuning

            # Map extracted image features back to their places in the original batch
            for i, pos in enumerate(image_idx):
                tmp_image_features[pos] = image_features_minibatch[i]

            # Handle auxiliary region features if enabled and boxes are provided
            if vision_tower_aux is not None and bbox_list is not None and len(bbox_list) > 0:
                if isinstance(self.get_model().get_vision_tower(), Qwen2_5_VlVisionTower):
                    patch_size = self.get_model().get_vision_tower().config.patch_size
                    vt_images_size_minibatch = [im_grid_thw[0][-2:]*patch_size for im_grid_thw in image_grid_thws]
                    region_features = self.encode_regions(images_minibatch_aux, bbox_list, vt_multi_level_features_minibatch, vt_images_size_minibatch)  # [mini_b, l, c]
            else:
                region_features = None

        # Same as above, but for video features if any
        if getattr(videos_minibatch, 'ndim', 0) == 5:  # batch consists of videos, [mini_b, c, t, h, w]
            video_features_minibatch = self.encode_videos(videos_minibatch)  # fake list [mini_b, t, l, c]
            for i, pos in enumerate(video_idx):
                tmp_image_features[pos] = video_features_minibatch[i]

        # Flatten the image-feature slot list into the proper order for the current batch
        new_tmp = []
        for image in tmp_image_features:
            # If there are multiple images per item, flatten them out
            if isinstance(image, list):
                t = len(image)
                for i in range(t):
                    new_tmp.append(image[i])
            else:
                new_tmp.append(image)
        image_features = new_tmp

        # =========================== Now, build the multimodal input & target sequences =========================

        if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
            raise NotImplementedError

        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask

        # Default construction of masks etc.
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # For each batch item, strip padded tokens based on attention_mask
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        # If neither the region auxiliary tower nor bboxes are present: process classic image-text input
        if vision_tower_aux is None and (bbox_list is None or all(x is None for x in bbox_list)):
            new_input_embeds = []
            new_labels = []
            new_input_ids = []
            cur_image_idx = 0
            image_nums_in_batch = []

            for batch_idx, cur_input_ids in enumerate(input_ids):
                num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
                image_nums_in_batch.append(num_images)
                # If there are no image markers, just get text features
                if num_images == 0:
                    cur_image_features = image_features[cur_image_idx]
                    cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                    cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                    new_input_embeds.append(cur_input_embeds)
                    new_labels.append(labels[batch_idx])
                    new_input_ids.append(cur_input_ids)
                    cur_image_idx += 1
                    continue

                # Split on image token indices: replace them with image features after conversion
                image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
                cur_input_ids_noim = []
                cur_labels = labels[batch_idx]
                cur_labels_noim = []
                for i in range(len(image_token_indices) - 1):
                    cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
                    cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
                split_sizes = [x.shape[0] for x in cur_labels_noim]
                cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
                cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)

                cur_new_input_embeds = []
                cur_new_labels = []
                cur_new_input_ids = []
                for i in range(num_images + 1):
                    # Interleave text and image features
                    cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                    cur_new_labels.append(cur_labels_noim[i])
                    cur_new_input_ids.append(cur_input_ids_noim[i])
                    if i < num_images:
                        cur_image_features = image_features[cur_image_idx].to(self.device)
                        cur_image_idx += 1
                        cur_new_input_embeds.append(cur_image_features)
                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
                        cur_new_input_ids.append(torch.full((cur_image_features.shape[0],), self.config.image_token_id, device=cur_labels.device, dtype=cur_labels.dtype))
                cur_new_input_embeds = torch.cat(cur_new_input_embeds)
                cur_new_labels = torch.cat(cur_new_labels)
                cur_new_input_ids = torch.cat(cur_new_input_ids)

                new_input_embeds.append(cur_new_input_embeds)
                new_labels.append(cur_new_labels)
                new_input_ids.append(cur_new_input_ids)
        # If region markers or region features are enabled in the config
        else:
            new_input_embeds = []
            new_labels = []
            new_input_ids = []
            cur_image_idx = 0
            image_nums_in_batch = []

            for batch_idx, cur_input_ids in enumerate(input_ids):
                cur_region_idx = 0
                # Detect image and region special token counts
                num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
                num_regions = (cur_input_ids == DEFAULT_REGION_INDEX).sum() if DEFAULT_REGION_INDEX in cur_input_ids else 0
                image_nums_in_batch.append(num_images)

                # If there are no markers, just do text embedding for this item
                if num_images == 0 and num_regions == 0:
                    cur_image_features = image_features[cur_image_idx]
                    cur_region_features = region_features[cur_region_idx]
                    cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                    cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_region_features[0:0]], dim=0)
                    new_input_embeds.append(cur_input_embeds)
                    new_labels.append(labels[batch_idx])
                    new_input_ids.append(cur_input_ids)
                    cur_image_idx += 1
                    continue

                # Get all special marker indices (image/region)
                image_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
                region_indices = torch.where(cur_input_ids == DEFAULT_REGION_INDEX)[0].tolist() if num_regions > 0 else []
                all_special_indices = sorted([-1] + image_indices + region_indices + [cur_input_ids.shape[0]])

                # Split out plain text chunks between special markers
                cur_input_ids_segments = []
                cur_labels = labels[batch_idx]
                cur_labels_segments = []

                for i in range(len(all_special_indices) - 1):
                    cur_input_ids_segments.append(cur_input_ids[all_special_indices[i]+1:all_special_indices[i+1]])
                    cur_labels_segments.append(cur_labels[all_special_indices[i]+1:all_special_indices[i+1]])

                # Project text ids to word embeddings
                split_sizes = [x.shape[0] for x in cur_labels_segments]
                cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_segments))
                if num_regions == 0 and vision_tower_aux is not None and region_features is not None:
                    cur_region_features = region_features[cur_region_idx]
                    temp_input_embeds = torch.cat([cur_input_embeds, cur_region_features[0:0]], dim=0)
                    cur_input_embeds = temp_input_embeds

                cur_input_embeds_segments = torch.split(cur_input_embeds, split_sizes, dim=0)

                # Reassemble text and image/region segments in order
                cur_new_input_embeds = []
                cur_new_labels = []
                cur_new_input_ids = []

                for i in range(len(all_special_indices) - 1):
                    # Insert the current text segment
                    cur_new_input_embeds.append(cur_input_embeds_segments[i])
                    cur_new_labels.append(cur_labels_segments[i])
                    cur_new_input_ids.append(cur_input_ids_segments[i])
                    # If the next marker is an image, insert its feature representation
                    if all_special_indices[i+1] in image_indices:
                        cur_image_features = image_features[cur_image_idx].to(self.device)
                        cur_image_idx += 1
                        cur_new_input_embeds.append(cur_image_features)
                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
                        cur_new_input_ids.append(torch.full((cur_image_features.shape[0],), self.config.image_token_id, device=cur_labels.device, dtype=cur_labels.dtype))

                    # If the next marker is a region token, insert the extracted region features
                    elif all_special_indices[i+1] in region_indices:
                        cur_region_features = region_features[batch_idx][cur_region_idx].to(self.device).unsqueeze(0)
                        cur_region_idx += 1
                        cur_new_input_embeds.append(cur_region_features)

                        cur_new_labels.append(torch.full((cur_region_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
                        cur_new_input_ids.append(torch.full((cur_region_features.shape[0],), DEFAULT_REGION_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
                # Combine for this batch item
                cur_new_input_embeds = torch.cat(cur_new_input_embeds)
                cur_new_labels = torch.cat(cur_new_labels)
                cur_new_input_ids = torch.cat(cur_new_input_ids)
                new_input_embeds.append(cur_new_input_embeds)
                new_labels.append(cur_new_labels)
                new_input_ids.append(cur_new_input_ids)
        # Truncate sequences to the maximum model length, if image+region tokens caused overflow
        tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # Pad sequences in the batch to the same length; compute batch masks
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
        new_input_ids_padded = torch.full((batch_size, max_len), self.config.bos_token_id, dtype=new_input_ids[0].dtype, device=new_input_ids[0].device)
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        # Left or right padding as per config; fill the padded tensors
        for i, (cur_new_embed, cur_new_labels, cur_new_input_ids) in enumerate(zip(new_input_embeds, new_labels, new_input_ids)):
            cur_len = cur_new_embed.shape[0]
            if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
                # Left pad: add zeros before text tokens/features
                new_input_embeds_padded.append(torch.cat((
                    torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
                    cur_new_embed
                ), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
            else:
                # Right pad: add zeros after text tokens/features
                new_input_embeds_padded.append(torch.cat((
                    cur_new_embed,
                    torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
                ), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    new_input_ids_padded[i, :cur_len] = cur_new_input_ids
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
        new_input_ids = new_input_ids_padded

        # Only set new_labels if the original labels were not None
        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        # Similarly handle the provided attention_mask/position_ids overrides
        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        # For Qwen2.5 vision towers, use and concatenate image_grid_thws for positional computations
        if isinstance(self.get_model().get_vision_tower(), Qwen2_5_VlVisionTower):
            image_grid_thws = []
            cur_image_idx = 0
            for num_images in image_nums_in_batch:
                if num_images == 0:
                    cur_image_idx += 1
                    continue
                image_grid_thws += image_grid_thws_minibatch[cur_image_idx:cur_image_idx+num_images]
                cur_image_idx += num_images

            if len(image_grid_thws) > 0:
                image_grid_thws = torch.cat(image_grid_thws, dim=0)
            else:
                image_grid_thws = None

            rope_index_kwargs = {
                "input_ids": new_input_ids,
                "image_grid_thw": image_grid_thws,
                "video_grid_thw": None,
                "attention_mask": attention_mask,
            }

            # Compute new position_ids and rope_deltas for the transformer (for rotary embeddings)
            position_ids, rope_deltas = self.get_rope_index(**rope_index_kwargs)
            cache_position = torch.arange(new_input_embeds.shape[1], device=new_input_embeds.device)
        else:
            rope_deltas = None
            cache_position = None
        # The final output is a tuple mimicking the HuggingFace prepare_inputs_for_generation return
        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels, rope_deltas, cache_position

    # Patch forward() of the HF CausalLM to allow multimodal embedding with images/regions
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        images: Optional[torch.FloatTensor] = None,
        images_aux: Optional[torch.FloatTensor] = None,
        bbox_list: Optional[torch.FloatTensor] = None,
        image_grid_thws: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:

        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
                rope_deltas,
                cache_position
            ) = self.prepare_inputs_labels_for_qwen2_5_vl_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                images_aux,
                bbox_list,
                image_grid_thws
            )

        if rope_deltas is not None:
            self.rope_deltas = rope_deltas

        # Call the base CausalLM forward, with possibly replaced multimodal embeddings
        out = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            rope_deltas=rope_deltas,
            cache_position=cache_position,
            second_per_grid_ts=second_per_grid_ts,
            return_dict=return_dict
        )
        return out

    # Prepare the model input dict for autoregressive generation (for use with generation methods like generate())
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        second_per_grid_ts=None,
        images: Optional[torch.FloatTensor] = None,
        images_aux: Optional[torch.FloatTensor] = None,
        bbox_list: Optional[torch.FloatTensor] = None,
        image_grid_thws: Optional[torch.FloatTensor] = None,
        **kwargs,
    ):
        # Wrap the parent logic so the extra multimodal kwargs are preserved
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            images=images,
            images_aux=images_aux,
            bbox_list=bbox_list,
            image_grid_thws=image_grid_thws,
        )
        return model_inputs

# Register our config and model with the HuggingFace transformers registry
AutoConfig.register("omchat_qwen2_5_vl", OmChatQwen25VLConfig)
AutoModelForCausalLM.register(OmChatQwen25VLConfig, OmChatQwen25VLForCausalLM)
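The marker-splitting pattern at the heart of prepare_inputs_labels_for_qwen2_5_vl_multimodal can be seen in isolation with a toy tensor. A standalone sketch: IMAGE_TOKEN_INDEX = -200 is assumed here (the common LLaVA convention), while the actual value lives in vlm_fo1.constants:

    import torch

    IMAGE_TOKEN_INDEX = -200  # assumed placeholder; see vlm_fo1.constants
    cur_input_ids = torch.tensor([5, 6, IMAGE_TOKEN_INDEX, 7, IMAGE_TOKEN_INDEX, 8])

    # Sentinels at -1 and len() turn every text chunk into a [start+1:end] slice.
    idx = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
    segments = [cur_input_ids[idx[i] + 1: idx[i + 1]] for i in range(len(idx) - 1)]
    # segments == [tensor([5, 6]), tensor([7]), tensor([8])]
    # Image (or region) features are then interleaved between consecutive segments.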
vlm_fo1/model/multimodal_encoder/__init__.py
ADDED
File without changes
vlm_fo1/model/multimodal_encoder/base_encoder.py
ADDED
@@ -0,0 +1,33 @@
import torch
import torch.nn as nn


class AbsVisionTower(nn.Module):
    @torch.no_grad()
    def forward(self, images):
        raise NotImplementedError

    @property
    def dummy_feature(self):
        raise NotImplementedError

    @property
    def dtype(self):
        raise NotImplementedError

    @property
    def device(self):
        raise NotImplementedError

    @property
    def config(self):
        raise NotImplementedError

    @property
    def hidden_size(self):
        raise NotImplementedError

    @property
    def num_patches(self):
        raise NotImplementedError
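A concrete tower is expected to fill in these hooks. A minimal hypothetical subclass, just to illustrate the contract (real towers such as Qwen2_5_VlVisionTower wrap actual pretrained encoders):

    class ToyVisionTower(AbsVisionTower):
        def __init__(self, hidden=64):
            super().__init__()
            self.proj = nn.Conv2d(3, hidden, kernel_size=14, stride=14)  # 14x14 patchify

        @torch.no_grad()
        def forward(self, images):  # images: [b, 3, h, w]
            feats = self.proj(images)                # [b, hidden, h/14, w/14]
            return feats.flatten(2).transpose(1, 2)  # [b, num_patches, hidden]

        @property
        def dtype(self):
            return self.proj.weight.dtype

        @property
        def device(self):
            return self.proj.weight.device

        @property
        def hidden_size(self):
            return self.proj.out_channels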
vlm_fo1/model/multimodal_encoder/builder.py
ADDED
@@ -0,0 +1,38 @@
# Builders for the different vision tower backbones (MM encoder visual modules)
from .qwen2_5_vl_encoder import Qwen2_5_VlVisionTower  # Main Qwen2.5 vision tower
from .davit_aux_encoder import DavitVisionTower as DavitVisionTowerAux  # Auxiliary DaViT vision tower

def build_vision_tower(vision_tower_cfg, **kwargs):
    """
    Use the model config to construct the main vision tower.

    vision_tower_cfg: should have the attribute mm_vision_tower
    Returns: an instance of the configured vision backbone
    """
    vision_tower_name = getattr(vision_tower_cfg, 'mm_vision_tower', None)
    # print(vision_tower_cfg)  # Debug print of the config being used

    # Check for the Qwen2.5-VL vision model in the tower name
    if "qwen2.5-vl" in vision_tower_name.lower():
        return Qwen2_5_VlVisionTower(vision_tower_name, args=vision_tower_cfg, **kwargs)

    # Raise a clear error for unknown towers
    raise ValueError(f'Unknown vision tower: {vision_tower_name}')

def build_vision_tower_aux(vision_tower_cfg, **kwargs):
    """
    Use the model config to construct the auxiliary (helper) vision tower.

    vision_tower_cfg: should have the attribute mm_vision_tower_aux
    Returns: an instance of the configured auxiliary vision backbone
    """
    vision_tower_aux = getattr(vision_tower_cfg, 'mm_vision_tower_aux', None)
    # Optionally print the config for debugging
    # print(vision_tower_cfg)

    # Check for the DaViT auxiliary vision model in the tower name
    if 'davit' in vision_tower_aux.lower():
        return DavitVisionTowerAux(vision_tower_aux, args=vision_tower_cfg, **kwargs)

    # Raise a clear error if the tower type is unknown
    raise ValueError(f'Unknown aux vision tower: {vision_tower_aux}')
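Both builders only require a config object exposing the expected attribute, so a quick smoke test can pass a SimpleNamespace. A sketch with placeholder tower names; the real constructors may read further attributes from the config:

    from types import SimpleNamespace

    cfg = SimpleNamespace(
        mm_vision_tower="Qwen/Qwen2.5-VL-7B-Instruct",           # placeholder name
        mm_vision_tower_aux="microsoft/Florence-2-large-davit",  # placeholder name
    )
    tower = build_vision_tower(cfg)          # matches "qwen2.5-vl"
    tower_aux = build_vision_tower_aux(cfg)  # matches "davit"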