diff --git a/README.md b/README.md index f404e1d12cf0d4923035a2aad37021ffdf511435..15d5b56a0b98e30975c384c662130a14452878bc 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,13 @@ --- title: VLM FO1 3B Demo -emoji: 📊 +emoji: 🐠 colorFrom: green -colorTo: indigo +colorTo: yellow sdk: gradio sdk_version: 5.49.1 -app_file: app.py +app_file: run.sh pinned: false license: apache-2.0 short_description: VLM-FO1-3B-Demo --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/demo/gradio_demo.py b/demo/gradio_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..1fdb15015df7e8eaa936439bafd461f2a3d4f4af --- /dev/null +++ b/demo/gradio_demo.py @@ -0,0 +1,253 @@ +import gradio as gr +from PIL import Image, ImageDraw, ImageFont +import re +import numpy as np +from skimage.measure import label, regionprops +from skimage.morphology import binary_dilation, disk +from detect_tools.upn import UPNWrapper +from vlm_fo1.model.builder import load_pretrained_model +from vlm_fo1.mm_utils import ( + prepare_inputs, + extract_predictions_to_indexes, +) +from vlm_fo1.task_templates import * +import torch + + +TASK_TYPES = { + "OD/REC": OD_template, + "ODCounting": OD_Counting_template, + "Region_OCR": "Please provide the ocr results of these regions in the image.", + "Brief_Region_Caption": "Provide a brief description for these regions in the image.", + "Detailed_Region_Caption": "Provide a detailed description for these regions in the image.", + "Grounding": Grounding_template, + "Viusal_Region_Reasoning": Viusal_Region_Reasoning_template, +} + + + +def detect_model(image, threshold=0.3): + proposals = upn_model.inference(image) + filtered_proposals = upn_model.filter(proposals, min_score=threshold) + return filtered_proposals['original_xyxy_boxes'][0][:100] + + +def multimodal_model(image, bboxes, text): + if '' in text: + print(text) + parts = [part.replace('\\n', '\n') for part in re.split(rf'()', text) if part.strip()] + print(parts) + content = [] + for part in parts: + if part == '': + content.append({"type": "image_url", "image_url": {"url": image}}) + else: + content.append({"type": "text", "text": part}) + else: + content = [{ + "type": "image_url", + "image_url": { + "url": image + } + }, { + "type": "text", + "text": text + }] + + messages = [ + { + "role": "user", + "content": content, + "bbox_list": bboxes + } + ] + generation_kwargs = prepare_inputs(model_path, model, image_processors, tokenizer, messages, + max_tokens=4096, top_p=0.05, temperature=0.0, do_sample=False) + with torch.inference_mode(): + output_ids = model.generate(**generation_kwargs) + outputs = tokenizer.decode(output_ids[0, generation_kwargs['inputs'].shape[1]:]).strip() + print("========output========\n", outputs) + + prediction_dict = extract_predictions_to_indexes(outputs) + + ans_bbox_json = [] + ans_bbox_list = [] + for k, v in prediction_dict.items(): + for box_index in v: + box_index = int(box_index) + if box_index < len(bboxes): + current_bbox = bboxes[box_index] + ans_bbox_json.append({ + "region_index": f"", + "xmin": current_bbox[0], + "ymin": current_bbox[1], + "xmax": current_bbox[2], + "ymax": current_bbox[3], + "label": k + }) + ans_bbox_list.append(current_bbox) + + return outputs, ans_bbox_json, ans_bbox_list + + + +def draw_bboxes(image, bboxes, labels=None): + image = image.copy() + draw = ImageDraw.Draw(image) + + for bbox in bboxes: + draw.rectangle(bbox, outline="red", width=3) + return image + + +def extract_bbox_and_original_image(edited_image: dict): + original_image = edited_image["background"] + bbox_list = [] + + if original_image is None: + return None, "Error, Please upload an image." + + if edited_image["layers"] is None or len(edited_image["layers"]) == 0: + return original_image, [] + + drawing_layer = edited_image["layers"][0] + alpha_channel = drawing_layer.getchannel('A') + alpha_np = np.array(alpha_channel) + + binary_mask = alpha_np > 0 + + structuring_element = disk(5) + dilated_mask = binary_dilation(binary_mask, structuring_element) + + labeled_image = label(dilated_mask) + regions = regionprops(labeled_image) + + for prop in regions: + y_min, x_min, y_max, x_max = prop.bbox + bbox_list.append((x_min, y_min, x_max, y_max)) + + return original_image, bbox_list + + +def process(image, prompt, threshold): + image, bbox_list = extract_bbox_and_original_image(image) + image = image.convert('RGB') + + if len(bbox_list) == 0: + # Get bboxes from detection model + bboxes = detect_model(image, threshold) + else: + bboxes = bbox_list + for idx in range(len(bboxes)): + prompt += f'' + + ans, ans_bbox_json, ans_bbox_list = multimodal_model(image, bboxes, prompt) + + + image_with_opn = draw_bboxes(image, bboxes) + + annotated_bboxes = [] + if len(ans_bbox_json) > 0: + for item in ans_bbox_json: + annotated_bboxes.append( + ((int(item['xmin']), int(item['ymin']), int(item['xmax']), int(item['ymax'])), item['label']) + ) + annotated_image = (image, annotated_bboxes) + + return annotated_image, image_with_opn, ans, ans_bbox_json + + +def show_label_input(choice): + return gr.update(visible=(choice == "OmDet")) + + +def update_btn(is_processing): + if is_processing: + return gr.update(value="Processing...", interactive=False) + else: + return gr.update(value="Submit", interactive=True) + + +def launch_demo(): + with gr.Blocks() as demo: + gr.Markdown("## VLM-FO1 Demo") + gr.Markdown(""" + **Instructions:** + 1. Upload an image, then you can either draw circular regions on it using the red brush as the input regions or let the detection model detect the regions for you. + 2. Select a task template and replace the [WRITE YOUR INPUT HERE] with your input targets, or write your own prompt.\n + For example, if you want to detect "person" and "dog", you can replace the [WRITE YOUR INPUT HERE] with "person, dog".\n + 3. Adjust the detection threshold if needed + 4. Click Submit to get results + """) + + with gr.Row(): + with gr.Column(): + img_input_draw = gr.ImageEditor( + label="Image Input", + image_mode="RGBA", + type="pil", + sources=['upload'], + brush=gr.Brush(colors=["#FF0000"], color_mode="fixed", default_size=2), + interactive=True + ) + + gr.Markdown("### Prompt & Parameters") + + def set_prompt_from_template(selected_task): + return gr.update(value=TASK_TYPES[selected_task].format("[WRITE YOUR INPUT HERE]")) + + task_type_input = gr.Dropdown( + choices=list(TASK_TYPES.keys()), + value="OD/REC", + label="Prompt Templates", + info="Select the prompt template for the task, or write your own prompt." + ) + + prompt_input = gr.Textbox( + label="Task Prompt", + value=TASK_TYPES["OD/REC"].format("[WRITE YOUR INPUT HERE]"), + lines=2, + ) + + task_type_input.change( + set_prompt_from_template, + inputs=task_type_input, + outputs=prompt_input + ) + + + threshold_input = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Detection Model Threshold") + submit_btn = gr.Button("Submit", variant="primary") + + with gr.Column(): + with gr.Accordion("Detection Result", open=True): + image_output_opn = gr.Image(label="Detection Result") + + image_output = gr.AnnotatedImage(label="Multimodal Model Output", height=500) + + result_output = gr.Textbox(label="Multimodal Model Output") + ans_bbox_json = gr.JSON(label="Extracted Detection Output") + + submit_btn.click(update_btn, inputs=[gr.State(True)], outputs=[submit_btn], queue=False).then( + process, + inputs=[img_input_draw, prompt_input, threshold_input], + outputs=[image_output, image_output_opn, result_output, ans_bbox_json], + queue=True + ).then(update_btn, inputs=[gr.State(False)], outputs=[submit_btn], queue=False) + + return demo + +if __name__ == "__main__": + model_path = './resources/VLM-FO1_Qwen2.5-VL-3B-v01' + upn_ckpt_path = "./resources/upn_large.pth" + tokenizer, model, image_processors = load_pretrained_model( + model_path=model_path, + device="cuda:0", + ) + upn_model = UPNWrapper(upn_ckpt_path) + + demo = launch_demo() + demo.launch() + + + diff --git a/detect_tools/upn/__init__.py b/detect_tools/upn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b590cc00c0a4308282ae60f372f2c9f364e2ffd6 --- /dev/null +++ b/detect_tools/upn/__init__.py @@ -0,0 +1,45 @@ +from . import models +from .builder import ( + ARCHITECTURES, + BACKBONES, + DECODERS, + ENCODERS, + POS_EMBEDDINGS, + build_architecture, + build_backbone, + build_decoder, + build_encoder, + build_position_embedding, +) +from .inference_wrapper import UPNWrapper +from .models.architecture import * +from .models.backbone import * +from .models.decoder import * +from .models.encoder import * +from .models.module import * +from .models.utils import * + +__all__ = [ + "BACKBONES", + "POS_EMBEDDINGS", + "ENCODERS", + "DECODERS", + "ARCHITECTURES", + "build_backbone", + "build_position_embedding", + "build_encoder", + "build_decoder", + "build_architecture", + "UPNWrapper", +] + +__all__ += ( + models.module.__all__ + + models.utils.__all__ + + models.architecture.__all__ + + models.backbone.__all__ + + models.encoder.__all__ + + models.decoder.__all__ + + models.module.__all__ + + models.utils.__all__ +) diff --git a/detect_tools/upn/builder.py b/detect_tools/upn/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..babd62b3e14c428540193e368f49fb3752ae0650 --- /dev/null +++ b/detect_tools/upn/builder.py @@ -0,0 +1,39 @@ +from mmengine import Registry, build_from_cfg + +BACKBONES = Registry("backbone") +POS_EMBEDDINGS = Registry("position_embedding") +FUSERS = Registry("fuser") +ENCODERS = Registry("encoder") +DECODERS = Registry("decoder") +ARCHITECTURES = Registry("architecture") + + +def build_backbone(cfg): + """Build encoder.""" + return build_from_cfg(cfg, BACKBONES) + + +def build_position_embedding(cfg): + """Build position embedding.""" + return build_from_cfg(cfg, POS_EMBEDDINGS) + + +def build_fuser(cfg): + """Build fuser.""" + return build_from_cfg(cfg, FUSERS) + + +def build_encoder(cfg): + """Build encoder.""" + return build_from_cfg(cfg, ENCODERS) + + +def build_decoder(cfg): + """Build decoder.""" + return build_from_cfg(cfg, DECODERS) + + +def build_architecture(cfg): + """Build architecture.""" + + return build_from_cfg(cfg, ARCHITECTURES) diff --git a/detect_tools/upn/configs/upn_large.py b/detect_tools/upn/configs/upn_large.py new file mode 100644 index 0000000000000000000000000000000000000000..d22b76c22be1474669954a27eb6d173cbc22c8eb --- /dev/null +++ b/detect_tools/upn/configs/upn_large.py @@ -0,0 +1,73 @@ +transformer_cfg = dict( + type="DeformableTransformer", + num_queries=900, + encoder_cfg=dict( + type="UPNEncoder", + encoder_layer_cfg=dict( + type="DeformableTransformerEncoderLayer", + activation="relu", + d_model=256, + dropout=0.0, + d_ffn=2048, + n_heads=8, + n_levels=5, + ), + d_model=256, + num_layers=6, + use_checkpoint=False, + use_transformer_ckpt=False, + ), + decoder_cfg=dict( + type="UPNDecoder", + decoder_layer_cfg=dict( + type="DeformableTransformerDecoderLayer", + activation="relu", + d_model=256, + n_heads=8, + dropout=0.0, + d_ffn=2048, + n_levels=5, + ), + d_model=256, + return_intermediate=True, + num_layers=6, + rm_dec_query_scale=True, + use_detached_boxes_dec_out=False, + ), + learnable_tgt_init=True, + random_refpoints_xy=False, + num_feature_levels=5, + two_stage_bbox_embed_share=False, + two_stage_class_embed_share=False, + two_stage_keep_all_tokens=False, + two_stage_learn_wh=False, + two_stage_type="standard", + binary_query_selection=False, +) + +vision_backbone = dict( + type="SwinWrapper", + backbone_cfg="swin_L_384_22k", + lr_backbone=1e-05, + dilation=False, + return_interm_indices=[0, 1, 2, 3], + backbone_freeze_keywords=None, + backbone_ckpt_path=None, + use_checkpoint=False, + position_embedding_cfg=dict( + type="PositionEmbeddingSineHW", + normalize=True, + num_pos_feats=128, + temperatureH=20, + temperatureW=20, + ), +) + +model = dict( + type="UPN", + vision_backbone_cfg=vision_backbone, + transformer_cfg=transformer_cfg, + num_queries=900, + dec_pred_bbox_embed_share=True, + dec_pred_class_embed_share=True, +) diff --git a/detect_tools/upn/inference_wrapper.py b/detect_tools/upn/inference_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..c523ae67374781b5a18a858c9f07ad7fd3383880 --- /dev/null +++ b/detect_tools/upn/inference_wrapper.py @@ -0,0 +1,237 @@ +import copy +import os +from typing import Dict, List, Union + +import numpy as np +import torch +from mmengine import Config +from PIL import Image +from torchvision.ops import nms + +import detect_tools.upn.transforms.transform as T +from detect_tools.upn import build_architecture +from detect_tools.upn.models.module import nested_tensor_from_tensor_list + + +def build_model( + ckpt_path: str, +): + current_path = os.path.dirname(os.path.abspath(__file__)) + config_file = f"configs/upn_large.py" + config_path = os.path.join(current_path, config_file) + model_cfg = Config.fromfile(config_path).model + model = build_architecture(model_cfg) + checkpoint = torch.load(ckpt_path, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + return model + + +class UPNWrapper: + """A wrapper class for the UPN model. + + Args: + ckpt_path (str): The path to the model checkpoint + """ + + def __init__(self, ckpt_path: str): + + self.model = build_model(ckpt_path) + self.model.eval() + self.model.to("cuda") + + def inference( + self, + image: List[Union[str, Image.Image]], + prompt_type: str = 'fine_grained_prompt', + ): + """Single image prediction. + + Args: + image List[Union[str, Image.Image]]: A list of image path or + PIL.Image.Image object. + prompt_type (str): The type of prompt to use for the prediction. Choice in + ['fine_grained_prompt', 'coarse_grained_prompt']. + + Returns: + Dict: Return dict in format: + { + "original_xyxy_boxes": (np.ndarray): Original prediction boxes in shape (batch_size, 900, 4), + "scores": (np.ndarray): Score in shape (batch_size, N) + } + """ + if not isinstance(image, list): + image = [image] + input_images, image_sizes = self.construct_input(image) + outputs = self._inference(input_images, prompt_type) + post_processed_outputs = self.postprocess(outputs, image_sizes) + return post_processed_outputs + + def _inference(self, input_images: List[torch.Tensor], prompt_type: str): + """Inference for T-Rex2 + + Args: + input_images (List[torch.Tensor]): Transformed Image + + Retunrs: + (Dict): Return dict with keys: + - query_features: (torch.Tensor): Query features in shape (batch_size, N, 256) + - pred_boxes: (torch.Tensor): Normalized prediction boxes in shape (batch_size, N, 4), + in cxcywh format + """ + input_images = nested_tensor_from_tensor_list(input_images) + input_images = input_images.to("cuda") + with torch.no_grad(): + outputs = self.model(input_images, prompt_type) + return outputs + + def construct_input(self, image: List[Union[str, Image.Image]]): + """Construct input for the model + + Args: + image (image: Union[List[Union[str, Image.Image]], torch.Tensor]): A list of image path or + PIL.Image.Image object. If the length of the list is more than 1, the model w`ill + perform batch inference. + + Returns: + Tuple[torch.Tensor, List[List[int]]]: A tuple containing the + input images, and the sizes of the input images. + """ + input_images = [] + image_sizes = [] + for _, img in enumerate(image): + if isinstance(img, str): + img = Image.open(img) + elif isinstance(img, Image.Image): + img = img + else: + raise ValueError( + "image must be either a string or a PIL.Image.Image object" + ) + W, H = img.size + image_sizes.append([H, W]) + # add image in tensor format + input_images.append(self.transform_image(img)) + return input_images, image_sizes + + def transform_image(self, image_pil: Image) -> Image: + """apply a set of transformations to a cv2 load image. + + Args: + image_path (str): The path to the image file. + + Returns: + Tuple[PIL.Image, torch.Tensor]: A tuple containing the original PIL Image and the + transformed image as a PyTorch tensor. + """ + transform = T.Compose( + [ + T.RandomResize([800], max_size=1333), + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) + transformed_image, _ = transform(image_pil, None) # 3, h, w + return transformed_image + + def postprocess( + self, + outputs: Dict[str, torch.Tensor], + image_pil_sizes: List[List[int]] = None, + ): + boxes = outputs["pred_boxes"].cpu() + scores = ( + outputs["pred_logits"].sigmoid().cpu() if "pred_logits" in outputs else None + ) + normalized_xyxy_boxes = [] + original_xyxy_boxes = [] + for batch_idx, (H, W) in enumerate(image_pil_sizes): + batch_boxes = boxes[batch_idx] # (num_queries, 4) + # from (cx, cy, w, h) to (x1, y1, x2, y2) + batch_boxes[:, 0] = batch_boxes[:, 0] - batch_boxes[:, 2] / 2 + batch_boxes[:, 1] = batch_boxes[:, 1] - batch_boxes[:, 3] / 2 + batch_boxes[:, 2] = batch_boxes[:, 0] + batch_boxes[:, 2] + batch_boxes[:, 3] = batch_boxes[:, 1] + batch_boxes[:, 3] + normalized_xyxy_boxes.append(copy.deepcopy(batch_boxes)) + # scale boxes + original_boxes = ( + batch_boxes.clone() + ) # Copy the normalized boxes to scale to original sizes + original_boxes[:, 0] = original_boxes[:, 0] * W + original_boxes[:, 1] = original_boxes[:, 1] * H + original_boxes[:, 2] = original_boxes[:, 2] * W + original_boxes[:, 3] = original_boxes[:, 3] * H + original_xyxy_boxes.append(original_boxes) + + original_xyxy_boxes = torch.stack(original_xyxy_boxes) + original_xyxy_boxes = original_xyxy_boxes.numpy() + + # sort everything by score from highest to lowest + sorted_original_boxes = [] + sorted_scores = [] + for i in range(len(normalized_xyxy_boxes)): + scores_i = scores[i] if scores is not None else None + # sort in descending order + sorted_indices = scores_i.squeeze(-1).argsort(descending=True) + sorted_original_boxes.append(original_xyxy_boxes[i][sorted_indices]) + sorted_scores.append(scores_i[sorted_indices]) + + original_xyxy_boxes = np.stack(sorted_original_boxes) + scores = torch.stack(sorted_scores) + + return dict( + original_xyxy_boxes=original_xyxy_boxes, + scores=scores, + ) + + def filter(self, result: Dict, min_score: float, nms_value: float = 0.8): + """Filter the UPN detection result. Only keep boxes with score above min_score + and apply Non-Maximum Suppression (NMS) to filter overlapping boxes. + + Args: + result (Dict): A dictionary containing detection results with 'original_xyxy_boxes' and 'scores'. + min_score (float): Minimum score threshold for keeping a box. + nms_value (float): NMS threshold for filtering boxes. + + Returns: + Dict: Filtered result containing 'original_xyxy_boxes' and 'scores' with the filtered boxes. + """ + filtered_result = {"original_xyxy_boxes": [], "scores": []} + + for boxes, scores in zip( + np.array(result["original_xyxy_boxes"]), result["scores"].numpy() + ): + # Filter out boxes with score below min_score + keep = scores >= min_score + boxes = boxes[keep[:, 0]] + scores = scores[keep[:, 0]][:, 0] + + if len(boxes) == 0: + return filtered_result + + # Convert to torch tensors + boxes = torch.tensor(boxes, dtype=torch.float32) + scores = torch.tensor(scores, dtype=torch.float32) + + # Apply Non-Maximum Suppression (NMS) + if nms_value > 0: + keep_indices = nms(boxes, scores, nms_value) + else: + keep_indices = torch.arange(len(boxes)) + + # Keep only the boxes that passed NMS + filtered_boxes = boxes[keep_indices].numpy().astype(np.int32) + filtered_scores = scores[keep_indices].numpy() + + # Sort the boxes by score in descending order + sorted_indices = np.argsort(filtered_scores)[::-1] + filtered_boxes = filtered_boxes[sorted_indices] + filtered_scores = filtered_scores[sorted_indices] + + # Round the scores to two decimal places + filtered_scores = [round(score, 2) for score in filtered_scores] + + # Store the filtered boxes and scores in the result dictionary + filtered_result["original_xyxy_boxes"].append(filtered_boxes.tolist()) + filtered_result["scores"].append(filtered_scores) + + return filtered_result diff --git a/detect_tools/upn/models/architecture/__init__.py b/detect_tools/upn/models/architecture/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d13d41f89544e3354c30d6e64f5636784f1b747c --- /dev/null +++ b/detect_tools/upn/models/architecture/__init__.py @@ -0,0 +1,4 @@ +from .deformable_transformer import DeformableTransformer +from .upn_model import UPN + +__all__ = ["UPN", "DeformableTransformer"] diff --git a/detect_tools/upn/models/architecture/deformable_transformer.py b/detect_tools/upn/models/architecture/deformable_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..2f94d3f0127b61dd09d48e7eadfcf3dddc64740c --- /dev/null +++ b/detect_tools/upn/models/architecture/deformable_transformer.py @@ -0,0 +1,336 @@ +import math +from typing import Dict, List + +import torch +import torch.nn as nn + +from detect_tools.upn import ARCHITECTURES, build_decoder, build_encoder +from detect_tools.upn.models.utils import (gen_encoder_output_proposals, + inverse_sigmoid) +from detect_tools.upn.ops.modules import MSDeformAttn + + +@ARCHITECTURES.register_module() +class DeformableTransformer(nn.Module): + """Implementation of Deformable DETR. + + Args: + encoder_cfg (Dict): Config for the TransformerEncoder. + decoder_cfg (Dict): Config for the TransformerDecoder. + num_queries (int): Number of queries. This is for matching part. Default: 900. + d_model (int): Dimension of the model. Default: 256. + num_feature_levels (int): Number of feature levels. Default: 1. + binary_query_selection (bool): Whether to use binary query selection. Default: False. + When using binary query selection, a linear with out channe =1 will be used to select + topk proposals. Otherwise, we will use ContrastiveAssign to select topk proposals. + learnable_tgt_init (bool): Whether to use learnable target init. Default: True. If False, + we will use the topk encoder features as the target init. + random_refpoints_xy (bool): Whether to use random refpoints xy. This is only used when + two_stage_type is not 'no'. Default: False. If True, we will use random refpoints xy. + two_stage_type (str): Type of two stage. Default: 'standard'. Options: 'no', 'standard' + two_stage_learn_wh (bool): Whether to learn the width and height of anchor boxes. Default: False. + two_stage_keep_all_tokens (bool): If False, the returned hs_enc, ref_enc, init_box_proposal + will only be the topk proposals. Otherwise, we will return all the proposals from the + encoder. Default: False. + two_stage_bbox_embed_share (bool): Whether to share the bbox embedding between the two stages. + Default: False. + two_stage_class_embed_share (bool): Whether to share the class embedding between the two stages. + rm_self_attn_layers (List[int]): The indices of the decoder layers to remove self-attention. + Default: None. + rm_detach (bool): Whether to detach the decoder output. Default: None. + embed_init_tgt (bool): If true, the target embedding is learnable. Otherwise, we will use + the topk encoder features as the target init. Default: True. + """ + + def __init__( + self, + encoder_cfg: Dict, + decoder_cfg: Dict, + mask_decoder_cfg: Dict = None, + num_queries: int = 900, + d_model: int = 256, + num_feature_levels: int = 4, + binary_query_selection: bool = False, + # init query (target) + learnable_tgt_init=True, + random_refpoints_xy=False, + # for two stage + two_stage_type: str = "standard", + two_stage_learn_wh: bool = False, + two_stage_keep_all_tokens: bool = False, + two_stage_bbox_embed_share: bool = False, + two_stage_class_embed_share: bool = False, + # evo of #anchors + rm_self_attn_layers: List[int] = None, + # for detach + rm_detach: bool = None, + with_encoder_out: bool = True, + ) -> None: + super().__init__() + self.binary_query_selection = binary_query_selection + self.num_queries = num_queries + self.num_feature_levels = num_feature_levels + self.rm_self_attn_layers = rm_self_attn_layers + self.d_model = d_model + self.two_stage_bbox_embed_share = two_stage_bbox_embed_share + self.two_stage_class_embed_share = two_stage_class_embed_share + + if self.binary_query_selection: + self.binary_query_selection_layer = nn.Linear(d_model, 1) + + # build encoder + self.encoder = build_encoder(encoder_cfg) + + # build decoder + self.decoder = build_decoder(decoder_cfg) + self.num_decoder_layers = self.decoder.num_layers + + # build sole mask decoder + if mask_decoder_cfg is not None: + self.mask_decoder = build_decoder(mask_decoder_cfg) + else: + self.mask_decoder = None + # level embedding + if num_feature_levels > 1: + self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model)) + + # learnable target embedding + self.learnable_tgt_init = learnable_tgt_init + assert learnable_tgt_init, "why not learnable_tgt_init" + + self.tgt_embed = nn.Embedding(num_queries, d_model) + nn.init.normal_(self.tgt_embed.weight.data) + + # for two stage + # TODO: this part is really confusing + self.two_stage_type = two_stage_type + self.two_stage_learn_wh = two_stage_learn_wh + self.two_stage_keep_all_tokens = two_stage_keep_all_tokens + assert two_stage_type in [ + "no", + "standard", + ], "unknown param {} of two_stage_type".format(two_stage_type) + self.with_encoder_out = with_encoder_out + if two_stage_type == "standard": + # anchor selection at the output of encoder + if with_encoder_out: + self.enc_output = nn.Linear(d_model, d_model) + self.enc_output_norm = nn.LayerNorm(d_model) + + if two_stage_learn_wh: + # import ipdb; ipdb.set_trace() + self.two_stage_wh_embedding = nn.Embedding(1, 2) + else: + self.two_stage_wh_embedding = None + + elif two_stage_type == "no": + self.init_ref_points( + num_queries, random_refpoints_xy + ) # init self.refpoint_embed + + self.enc_out_class_embed = None # this will be initialized outside of the model + self.enc_out_bbox_embed = None # this will be initialized outside of the model + + # remove some self_attn_layers or rm_detach + self._reset_parameters() + + self.rm_self_attn_layers = rm_self_attn_layers + if rm_self_attn_layers is not None: + # assert len(rm_self_attn_layers) == num_decoder_layers + print( + "Removing the self-attn in {} decoder layers".format( + rm_self_attn_layers + ) + ) + for lid, dec_layer in enumerate(self.decoder.layers): + if lid in rm_self_attn_layers: + dec_layer.rm_self_attn_modules() + + self.rm_detach = rm_detach + if self.rm_detach: + assert isinstance(rm_detach, list) + assert any([i in ["enc_ref", "enc_tgt", "dec"] for i in rm_detach]) + self.decoder.rm_detach = rm_detach + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MSDeformAttn): + m._reset_parameters() + if self.num_feature_levels > 1 and self.level_embed is not None: + nn.init.normal_(self.level_embed) + + if self.two_stage_learn_wh: + nn.init.constant_( + self.two_stage_wh_embedding.weight, math.log(0.05 / (1 - 0.05)) + ) + + def init_ref_points(self, num_queries: int, random_refpoints_xy: bool = False): + """Initialize learnable reference points for each query. + + Args: + num_queries (int): number of queries + random_refpoints_xy (bool, optional): whether to init the refpoints randomly. + Defaults to False. + """ + self.refpoint_embed = nn.Embedding(num_queries, 4) + if random_refpoints_xy: + self.refpoint_embed.weight.data[:, :2].uniform_(0, 1) + self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid( + self.refpoint_embed.weight.data[:, :2] + ) + self.refpoint_embed.weight.data[:, :2].requires_grad = False + + def get_valid_ratio(self, mask): + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def forward( + self, + src_flatten: torch.Tensor, + lvl_pos_embed_flatten: torch.Tensor, + level_start_index: List[int], + spatial_shapes: List[torch.Tensor], + valid_ratios: List[torch.Tensor], + mask_flatten: torch.Tensor, + prompt_type: str, + ) -> List[torch.Tensor]: + """Forward function.""" + memory = self.encoder( + src_flatten, + pos=lvl_pos_embed_flatten, + level_start_index=level_start_index, + spatial_shapes=spatial_shapes, + valid_ratios=valid_ratios, + key_padding_mask=mask_flatten, + ) + batch_size = src_flatten.shape[0] + crop_region_features = torch.zeros(batch_size, 1, self.d_model).to( + memory.device + ) + if prompt_type == "fine_grained_prompt": + crop_region_features = ( + self.fine_grained_prompt.weight[0] + .unsqueeze(0) + .unsqueeze(0) + .repeat(batch_size, 1, 1) + ) + elif prompt_type == "coarse_grained_prompt": + crop_region_features = ( + self.coarse_grained_prompt.weight[0] + .unsqueeze(0) + .unsqueeze(0) + .repeat(batch_size, 1, 1) + ) + pad_mask = torch.ones(batch_size, 1).to(crop_region_features.device).bool() + self_attn_mask = torch.ones(batch_size, 1, 1).to(crop_region_features.device) + ref_dict = dict( + encoded_ref_feature=crop_region_features, + pad_mask=pad_mask, + self_attn_mask=self_attn_mask, + prompt_type="universal_prompt", + ) + + ( + refpoint_embed, + tgt, + init_box_proposal, + ) = self.get_two_stage_proposal(memory, mask_flatten, spatial_shapes, ref_dict) + + hs, references = self.decoder( + tgt=tgt.transpose(0, 1), + tgt_key_padding_mask=None, + memory=memory.transpose(0, 1), + memory_key_padding_mask=mask_flatten, + pos=lvl_pos_embed_flatten.transpose(0, 1), + refpoints_unsigmoid=refpoint_embed.transpose(0, 1), + level_start_index=level_start_index, + spatial_shapes=spatial_shapes, + valid_ratios=valid_ratios, + tgt_mask=None, + # we ~ the mask . False means use the token; True means pad the token + ) + hs_enc = ref_enc = None + return ( + hs, + references, + ref_dict, + ) + + def get_two_stage_proposal( + self, + memory: torch.Tensor, + mask_flatten: torch.Tensor, + spatial_shapes: List[torch.Tensor], + ref_dict: Dict, + ) -> List[torch.Tensor]: + """Two stage proposal generation for decoder + + Args: + memory (torch.Tensor): Image encoded feature. [bs, n, 256] + mask_flatten (torch.Tensor): Flattened mask. [bs, n] + spatial_shapes (List[torch.Tensor]): Spatial shapes of each feature map. [bs, num_levels, 2] + refpoint_embed_dn (torch.Tensor): Denosing refpoint embedding. [bs, num_dn_queries, 256] + tgt_dn (torch.Tensor): Denosing target embedding. [bs, num_dn_queries, 256] + ref_dict (Dict): A dict containing all kinds of reference image related features. + """ + bs = memory.shape[0] + input_hw = None + output_memory, output_proposals = gen_encoder_output_proposals( + memory, mask_flatten, spatial_shapes, input_hw + ) + output_memory = self.enc_output_norm(self.enc_output(output_memory)) + + if self.binary_query_selection: # Unused + topk_logits = self.binary_query_selection_layer(output_memory).squeeze(-1) + else: + if ref_dict is not None: + enc_outputs_class_unselected = self.enc_out_class_embed( + output_memory, ref_dict + ) # this is not a linear layer for prediction. But contrastive similaryity, shape [B, len_image, len_text] + else: + enc_outputs_class_unselected = self.enc_out_class_embed(output_memory) + topk_logits = enc_outputs_class_unselected.max(-1)[ + 0 + ] # shape [B, len_image] + enc_outputs_coord_unselected = ( + self.enc_out_bbox_embed(output_memory) + output_proposals + ) # (bs, \sum{hw}, 4) unsigmoid + topk = self.num_queries + + try: + topk_proposals = torch.topk(topk_logits, topk, dim=1)[1] # bs, nq + except: + raise ValueError(f"dadad {topk_logits.shape}") + + # gather boxes + refpoint_embed_undetach = torch.gather( + enc_outputs_coord_unselected, + 1, + topk_proposals.unsqueeze(-1).repeat(1, 1, 4), + ) # unsigmoid + refpoint_embed_ = refpoint_embed_undetach.detach() + init_box_proposal = torch.gather( + output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4) + ).sigmoid() # sigmoid + # gather tgt + tgt_undetach = torch.gather( + output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model) + ) + tgt_ = ( + self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1) + ) # nq, bs, d_model + refpoint_embed, tgt = refpoint_embed_, tgt_ + + return ( + refpoint_embed, + tgt, + init_box_proposal, + ) diff --git a/detect_tools/upn/models/architecture/upn_model.py b/detect_tools/upn/models/architecture/upn_model.py new file mode 100644 index 0000000000000000000000000000000000000000..682562104d8c08ca7fec3548232cb03314416e70 --- /dev/null +++ b/detect_tools/upn/models/architecture/upn_model.py @@ -0,0 +1,343 @@ +import copy +from typing import Dict, List, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from detect_tools.upn import ARCHITECTURES, build_architecture, build_backbone +from detect_tools.upn.models.module import (MLP, ContrastiveAssign, NestedTensor, + nested_tensor_from_tensor_list) +from detect_tools.upn.models.utils import inverse_sigmoid + + +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +@ARCHITECTURES.register_module() +class UPN(nn.Module): + """Implementation of UPN""" + + def __init__( + self, + vision_backbone_cfg: Dict, + transformer_cfg: Dict, + num_queries: int, + dec_pred_class_embed_share=True, + dec_pred_bbox_embed_share=True, + decoder_sa_type="sa", + ): + super().__init__() + # build vision backbone + self.backbone = build_backbone(vision_backbone_cfg) + # build transformer + self.transformer = build_architecture(transformer_cfg) + + self.hidden_dim = self.transformer.d_model + + # for dn training + self.num_queries = num_queries + self.num_feature_levels = self.transformer.num_feature_levels + + # prepare projection layer for vision feature + self.input_proj = self.prepare_vision_feature_projection_layer( + self.backbone, + self.transformer.num_feature_levels, + self.hidden_dim, + self.transformer.two_stage_type, + ) + # prepare prediction head + self.prepare_prediction_head( + dec_pred_class_embed_share, + dec_pred_bbox_embed_share, + self.hidden_dim, + self.transformer.num_decoder_layers, + ) + + self.decoder_sa_type = decoder_sa_type + assert decoder_sa_type in ["sa", "ca_label", "ca_content"] + # self.replace_sa_with_double_ca = replace_sa_with_double_ca + + for layer in self.transformer.decoder.layers: + layer.label_embedding = None + self.label_embedding = None + + # build a unversal token + self.transformer.fine_grained_prompt = nn.Embedding(1, self.hidden_dim) + self.transformer.coarse_grained_prompt = nn.Embedding(1, self.hidden_dim) + + self._reset_parameters() + + def forward(self, samples: NestedTensor, prompt_type: str = None) -> Dict: + """Foward function""" + self.device = samples.device + + ( + src_flatten, + lvl_pos_embed_flatten, + level_start_index, + spatial_shapes, + valid_ratios, + mask_flatten, + ) = self.forward_backbone_encoder(samples) + + ( + hs, + reference, + ref_dict, + ) = self.transformer( + src_flatten, + lvl_pos_embed_flatten, + level_start_index, + spatial_shapes, + valid_ratios, + mask_flatten, + prompt_type, + ) + + # deformable-detr-line anchor update + outputs_coord_list = [] + outputs_class = [] + + for layer_idx, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate( + zip(reference[:-1], self.bbox_embed, hs) + ): + layer_delta_unsig = layer_bbox_embed(layer_hs) + layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig) + layer_outputs_unsig = layer_outputs_unsig.sigmoid() + outputs_coord_list.append(layer_outputs_unsig) + + outputs_coord_list = torch.stack(outputs_coord_list) + + if ref_dict is None: + # build a mock outputs_class for mask_dn training + outputs_class = torch.zeros( + outputs_coord_list.shape[0], + outputs_coord_list.shape[1], + outputs_coord_list.shape[2], + self.hidden_dim, + ) + else: + outputs_class = torch.stack( + [ + layer_cls_embed(layer_hs, ref_dict) + for layer_cls_embed, layer_hs in zip(self.class_embed, hs) + ] + ) + + out = { + "pred_logits": outputs_class[-1], + "pred_boxes": outputs_coord_list[-1], + } + out["ref_dict"] = ref_dict + return out + + def forward_backbone_encoder(self, samples: NestedTensor) -> Tuple: + # pass through backbone + if isinstance(samples, (list, torch.Tensor)): + samples = nested_tensor_from_tensor_list(samples) + features, poss = self.backbone(samples) + # project features + srcs = [] + masks = [] + for l, feat in enumerate(features): + src, mask = feat.decompose() + srcs.append(self.input_proj[l](src)) # downsample the feature map to 256 + masks.append(mask) + assert mask is not None + + if self.num_feature_levels > len( + srcs + ): # add more feature levels by downsampling the last feature map + _len_srcs = len(srcs) + for l in range(_len_srcs, self.num_feature_levels): + if l == _len_srcs: + src = self.input_proj[l](features[-1].tensors) + else: + src = self.input_proj[l](srcs[-1]) + m = samples.mask + mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to( + torch.bool + )[0] + pos_l = self.backbone.forward_pos_embed_only( + NestedTensor(src, mask) + ).to(src.dtype) + srcs.append(src) + masks.append(mask) + poss.append(pos_l) + + # prepare input for encoder with the following steps: + # 1. flatten the feature maps and masks + # 2. Add positional embedding and level embedding + # 3. Calculate the valid ratio of each feature map based on the mask + src_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, poss)): + bs, c, h, w = src.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + + src = src.flatten(2).transpose(1, 2) # bs, hw, c + mask = mask.flatten(1) # bs, hw + pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c + if self.num_feature_levels > 1 and self.transformer.level_embed is not None: + lvl_pos_embed = pos_embed + self.transformer.level_embed[lvl].view( + 1, 1, -1 + ) + else: + lvl_pos_embed = pos_embed + lvl_pos_embed_flatten.append(lvl_pos_embed) + src_flatten.append(src) + mask_flatten.append(mask) + src_flatten = torch.cat(src_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + spatial_shapes = torch.as_tensor( + spatial_shapes, dtype=torch.long, device=src_flatten.device + ) + level_start_index = torch.cat( + (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]) + ) + valid_ratios = torch.stack( + [self.transformer.get_valid_ratio(m) for m in masks], 1 + ) + + return ( + src_flatten, + lvl_pos_embed_flatten, + level_start_index, + spatial_shapes, + valid_ratios, + mask_flatten, + ) + + def prepare_vision_feature_projection_layer( + self, + backbone: nn.Module, + num_feature_levels: int, + hidden_dim: int, + two_stage_type: str, + ) -> nn.ModuleList: + """Prepare projection layer to map backbone feature to hidden dim. + + Args: + backbone (nn.Module): Backbone. + num_feature_levels (int): Number of feature levels. + hidden_dim (int): Hidden dim. + two_stage_type (str): Type of two stage. + + Returns: + nn.ModuleList: Projection layer. + """ + if num_feature_levels > 1: + num_backbone_outs = len(backbone.num_channels) + input_proj_list = [] + for _ in range(num_backbone_outs): + in_channels = backbone.num_channels[_] + input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, hidden_dim, kernel_size=1), + nn.GroupNorm(32, hidden_dim), + ) + ) + for _ in range(num_feature_levels - num_backbone_outs): + input_proj_list.append( + nn.Sequential( + nn.Conv2d( + in_channels, hidden_dim, kernel_size=3, stride=2, padding=1 + ), + nn.GroupNorm(32, hidden_dim), + ) + ) + in_channels = hidden_dim + input_proj = nn.ModuleList(input_proj_list) + else: + assert ( + two_stage_type == "no" + ), "two_stage_type should be no if num_feature_levels=1 !!!" + input_proj = nn.ModuleList( + [ + nn.Sequential( + nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1), + nn.GroupNorm(32, hidden_dim), + ) + ] + ) + return input_proj + + def prepare_prediction_head( + self, + dec_pred_class_embed_share: bool, + dec_pred_bbox_embed_share: bool, + hidden_dim: int, + num_decoder_layers: int, + ) -> Union[nn.ModuleList, nn.ModuleList]: + """Prepare prediction head. Including class embed and bbox embed. + + Args: + dec_pred_class_embed_share (bool): Whether to share class embed for all decoder layers. + dec_pred_bbox_embed_share (bool): Whether to share bbox embed for all decoder layers. + im (int): Hidden dim. + num_decoder_layers (int): Number of decoder layers. + + """ + _class_embed = ContrastiveAssign() + _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) + nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0) + nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0) + if dec_pred_bbox_embed_share: + box_embed_layerlist = [_bbox_embed for _ in range(num_decoder_layers)] + else: + box_embed_layerlist = [ + copy.deepcopy(_bbox_embed) for i in range(num_decoder_layers) + ] + if dec_pred_class_embed_share: + class_embed_layerlist = [_class_embed for i in range(num_decoder_layers)] + else: + class_embed_layerlist = [ + copy.deepcopy(_class_embed) for i in range(num_decoder_layers) + ] + bbox_embed = nn.ModuleList(box_embed_layerlist) + class_embed = nn.ModuleList(class_embed_layerlist) + self.bbox_embed = bbox_embed + self.class_embed = class_embed + + # iniitalize bbox embed and class embed in transformer + self.transformer.decoder.bbox_embed = bbox_embed + self.transformer.decoder.class_embed = class_embed + + if self.transformer.two_stage_type != "no": + if self.transformer.two_stage_bbox_embed_share: + assert dec_pred_class_embed_share and dec_pred_bbox_embed_share + self.transformer.enc_out_bbox_embed = _bbox_embed + else: + self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed) + + if self.transformer.two_stage_class_embed_share: + assert dec_pred_class_embed_share and dec_pred_bbox_embed_share + self.transformer.enc_out_class_embed = _class_embed + else: + self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed) + + self.refpoint_embed = None + + def _reset_parameters(self): + # init input_proj + for proj in self.input_proj: + + nn.init.xavier_uniform_(proj[0].weight, gain=1) + nn.init.constant_(proj[0].bias, 0) diff --git a/detect_tools/upn/models/backbone/__init__.py b/detect_tools/upn/models/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8e9aff777c6f0381797208d7ba60d7624553347e --- /dev/null +++ b/detect_tools/upn/models/backbone/__init__.py @@ -0,0 +1,4 @@ +from .swin import SwinTransformer +from .wrapper import SwinWrapper + +__all__ = ["SwinWrapper", "SwinTransformer"] diff --git a/detect_tools/upn/models/backbone/swin.py b/detect_tools/upn/models/backbone/swin.py new file mode 100644 index 0000000000000000000000000000000000000000..4874a33a10dc7708969e460aaebc11107226ad54 --- /dev/null +++ b/detect_tools/upn/models/backbone/swin.py @@ -0,0 +1,814 @@ +from typing import Dict, List + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + +from detect_tools.upn import BACKBONES +from detect_tools.upn.models.module import NestedTensor + + +class Mlp(nn.Module): + """Multilayer perceptron.""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view( + B, H // window_size, W // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__( + self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0 + ).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=0.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B_, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1, + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( + 1 + ).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + """Swin Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__( + self, + dim, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert ( + 0 <= self.shift_size < self.window_size + ), "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + self.H = None + self.W = None + + def forward(self, x, mask_matrix): + """Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. + """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll( + x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) + ) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size + ) # nW*B, window_size, window_size, C + x_windows = x_windows.view( + -1, self.window_size * self.window_size, C + ) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn( + x_windows, mask=attn_mask + ) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2) + ) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + """Patch Merging Layer + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + ): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=( + drop_path[i] if isinstance(drop_path, list) else drop_path + ), + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W): + """Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size + ) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( + attn_mask == 0, float(0.0) + ) + + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +@BACKBONES.register_module() +class SwinTransformer(nn.Module): + """Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + dilation (bool): if True, the output size if 16x downsample, ow 32x downsample. + """ + + def __init__( + self, + pretrain_img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + dilation=False, + use_checkpoint=False, + ): + super().__init__() + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.dilation = dilation + + if use_checkpoint: + print("use_checkpoint!!!!!!!!!!!!!!!!!!!!!!!!") + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None, + ) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [ + pretrain_img_size[0] // patch_size[0], + pretrain_img_size[1] // patch_size[1], + ] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]) + ) + trunc_normal_(self.absolute_pos_embed, std=0.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + # prepare downsample list + downsamplelist = [PatchMerging for i in range(self.num_layers)] + downsamplelist[-1] = None + num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)] + if self.dilation: + downsamplelist[-2] = None + num_features[-1] = int(embed_dim * 2 ** (self.num_layers - 1)) // 2 + for i_layer in range(self.num_layers): + layer = BasicLayer( + # dim=int(embed_dim * 2 ** i_layer), + dim=num_features[i_layer], + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], + norm_layer=norm_layer, + # downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + downsample=downsamplelist[i_layer], + use_checkpoint=use_checkpoint, + ) + self.layers.append(layer) + + # num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f"norm{i_layer}" + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def forward_raw(self, x: torch.Tensor) -> List[torch.Tensor]: + """Forward function.""" + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate( + self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic" + ) + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = [] + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + # import ipdb; ipdb.set_trace() + + if i in self.out_indices: + norm_layer = getattr(self, f"norm{i}") + x_out = norm_layer(x_out) + + out = ( + x_out.view(-1, H, W, self.num_features[i]) + .permute(0, 3, 1, 2) + .contiguous() + ) + outs.append(out) + + return tuple(outs) + + def forward(self, tensor_list: NestedTensor) -> Dict: + """Forward function. + + Args: + tensor_list (NestedTensor): NestedTensor object containing tensors and masks. + + Returns: + Dict: Dict containing output tensors. The structure is as follows. + - 0: NestedTensor from stage 0. + - 1: NestedTensor from stage 1. + - 2: NestedTensor from stage 2. + - 3: NestedTensor from stage 3. + """ + x = tensor_list.tensors + + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate( + self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic" + ) + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = [] + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + + if i in self.out_indices: + norm_layer = getattr(self, f"norm{i}") + x_out = norm_layer(x_out) + + out = ( + x_out.view(-1, H, W, self.num_features[i]) + .permute(0, 3, 1, 2) + .contiguous() + ) + outs.append(out) + + # collect for nesttensors + outs_dict = {} + for idx, out_i in enumerate(outs): + m = tensor_list.mask + assert m is not None + mask = F.interpolate(m[None].float(), size=out_i.shape[-2:]).to(torch.bool)[ + 0 + ] + outs_dict[idx] = NestedTensor(out_i, mask) + + return outs_dict + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer, self).train(mode) + self._freeze_stages() + + +def build_swin_transformer(modelname, pretrain_img_size, **kw): + assert modelname in [ + "swin_T_224_1k", + "swin_B_224_22k", + "swin_B_384_22k", + "swin_L_224_22k", + "swin_L_384_22k", + ] + + model_para_dict = { + "swin_T_224_1k": dict( + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7 + ), + "swin_B_224_22k": dict( + embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7 + ), + "swin_B_384_22k": dict( + embed_dim=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32], + window_size=12, + ), + "swin_L_224_22k": dict( + embed_dim=192, + depths=[2, 2, 18, 2], + num_heads=[6, 12, 24, 48], + window_size=7, + ), + "swin_L_384_22k": dict( + embed_dim=192, + depths=[2, 2, 18, 2], + num_heads=[6, 12, 24, 48], + window_size=12, + ), + } + kw_cgf = model_para_dict[modelname] + kw_cgf.update(kw) + model = SwinTransformer(pretrain_img_size=pretrain_img_size, **kw_cgf) + return model diff --git a/detect_tools/upn/models/backbone/wrapper.py b/detect_tools/upn/models/backbone/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..2fe7a4c8d4742418a480efe9e74b86fce0ef2cee --- /dev/null +++ b/detect_tools/upn/models/backbone/wrapper.py @@ -0,0 +1,297 @@ +from typing import Dict, List, Tuple, Union + +import torch +import torch.nn as nn + +from detect_tools.upn import BACKBONES, build_backbone, build_position_embedding +from detect_tools.upn.models.module import NestedTensor +from detect_tools.upn.models.utils import clean_state_dict + + +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other models than torchvision.models.resnet[18,34,50,101] + produce nans. + """ + + def __init__(self, n): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, self)._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = 1e-5 + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class Joiner(nn.Module): + """A wrapper for the backbone and the position embedding. + + Args: + backbone_cfg (Dict): Config dict to build backbone. + position_embedding_cfg (Dict): Config dict to build position embedding. + """ + + def __init__(self, backbone: nn.Module, position_embedding: nn.Module) -> None: + super().__init__() + self.backbone = backbone + self.pos_embed = position_embedding + + def forward( + self, tensor_list: NestedTensor + ) -> Union[List[NestedTensor], List[torch.Tensor]]: + """Forward function. + + Args: + tensor_list (NestedTensor): NestedTensor wrapping the input tensor. + + Returns: + [List[NestedTensor]: A list of feature map in NestedTensor format. + List[torch.Tensor]: A list of position encoding. + """ + + xs = self.backbone(tensor_list) + out: List[NestedTensor] = [] + pos = [] + for layer_idx, x in xs.items(): + out.append(x) + # position encoding + pos.append(self.pos_embed(x).to(x.tensors.dtype)) + + return out, pos + + def forward_pos_embed_only(self, x: NestedTensor) -> torch.Tensor: + """Forward function for position embedding only. This is used to generate additional layer + + Args: + x (NestedTensor): NestedTensor wrapping the input tensor. + + Returns: + [List[torch.Tensor]: A list of position encoding. + """ + return self.pos_embed(x) + + +@BACKBONES.register_module() +class SwinWrapper(nn.Module): + """A wrapper for swin transformer. + + Args: + backbone_cfg Union[Dict, str]: Config dict to build backbone. If given a str name, we + will call `get_swin_config` to get the config dict. + dilation (bool): Whether to use dilation in stage 4. + position_embedding_cfg (Dict): Config dict to build position embedding. + lr_backbone (float): Learning rate of the backbone. + return_interm_layers (List[int]): Which layers to return. + backbone_freeze_keywords (List[str]): List of keywords to freeze the backbone. + use_checkpoint (bool): Whether to use checkpoint. Default: False. + ckpt_path (str): Checkpoint path. Default: None. + use_pretrained_ckpt (bool): Whether to use pretrained checkpoint. Default: True. + """ + + def __init__( + self, + backbone_cfg: Union[Dict, str], + dilation: bool, + position_embedding_cfg: Dict, + lr_backbone: float, + return_interm_indices: List[int], + backbone_freeze_keywords: List[str], + use_checkpoint: bool = False, + backbone_ckpt_path: str = None, + ) -> None: + super(SwinWrapper, self).__init__() + pos_embedding = build_position_embedding(position_embedding_cfg) + train_backbone = lr_backbone > 0 + if not train_backbone: + raise ValueError("Please set lr_backbone > 0") + assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]] + + # build backbone + if isinstance(backbone_cfg, str): + assert ( + backbone_cfg + in backbone_cfg + in [ + "swin_T_224_1k", + "swin_B_224_22k", + "swin_B_384_22k", + "swin_L_224_22k", + "swin_L_384_22k", + ] + ) + pretrain_img_size = int(backbone_cfg.split("_")[-2]) + backbone_cfg = get_swin_config( + backbone_cfg, + pretrain_img_size, + out_indices=tuple(return_interm_indices), + dilation=dilation, + use_checkpoint=use_checkpoint, + ) + backbone = build_backbone(backbone_cfg) + + # freeze some layers + if backbone_freeze_keywords is not None: + for name, parameter in backbone.named_parameters(): + for keyword in backbone_freeze_keywords: + if keyword in name: + parameter.requires_grad_(False) + break + + # load checkpoint + if backbone_ckpt_path is not None: + print("Loading backbone checkpoint from {}".format(backbone_ckpt_path)) + checkpoint = torch.load(backbone_ckpt_path, map_location="cpu")["model"] + from collections import OrderedDict + + def key_select_function(keyname): + if "head" in keyname: + return False + if dilation and "layers.3" in keyname: + return False + return True + + _tmp_st = OrderedDict( + { + k: v + for k, v in clean_state_dict(checkpoint).items() + if key_select_function(k) + } + ) + _tmp_st_output = backbone.load_state_dict(_tmp_st, strict=False) + print(str(_tmp_st_output)) + + bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :] + assert len(bb_num_channels) == len( + return_interm_indices + ), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}" + + model = Joiner(backbone, pos_embedding) + model.num_channels = bb_num_channels + self.num_channels = bb_num_channels + self.model = model + + def forward( + self, tensor_list: NestedTensor + ) -> Union[List[NestedTensor], List[torch.Tensor]]: + """Forward function. + + Args: + tensor_list (NestedTensor): NestedTensor wrapping the input tensor. + + Returns: + [List[NestedTensor]: A list of feature map in NestedTensor format. + List[torch.Tensor]: A list of position encoding. + """ + + return self.model(tensor_list) + + def forward_pos_embed_only(self, tensor_list: NestedTensor) -> torch.Tensor: + """Forward function to get position embedding only. + + Args: + tensor_list (NestedTensor): NestedTensor wrapping the input tensor. + + Returns: + torch.Tensor: Position embedding. + """ + return self.model.forward_pos_embed_only(tensor_list) + + +def get_swin_config(modelname: str, pretrain_img_size: Tuple[int, int], **kw): + """Get swin config dict. + + Args: + modelname (str): Name of the model. + pretrain_img_size (Tuple[int, int]): Image size of the pretrain model. + kw (Dict): Other key word arguments. + + Returns: + Dict: Config dict. + str: Path to the pretrained checkpoint. + """ + assert modelname in [ + "swin_T_224_1k", + "swin_B_224_22k", + "swin_B_384_22k", + "swin_L_224_22k", + "swin_L_384_22k", + ] + model_para_dict = { + "swin_T_224_1k": dict( + type="SwinTransformer", + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + ), + "swin_B_224_22k": dict( + type="SwinTransformer", + embed_dim=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32], + window_size=7, + ), + "swin_B_384_22k": dict( + type="SwinTransformer", + embed_dim=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32], + window_size=12, + ), + "swin_L_224_22k": dict( + type="SwinTransformer", + embed_dim=192, + depths=[2, 2, 18, 2], + num_heads=[6, 12, 24, 48], + window_size=7, + ), + "swin_L_384_22k": dict( + type="SwinTransformer", + embed_dim=192, + depths=[2, 2, 18, 2], + num_heads=[6, 12, 24, 48], + window_size=12, + ), + } + kw_cgf = model_para_dict[modelname] + kw_cgf.update(kw) + kw_cgf.update(dict(pretrain_img_size=pretrain_img_size)) + return kw_cgf diff --git a/detect_tools/upn/models/decoder/__init__.py b/detect_tools/upn/models/decoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..96f11bd3cc1b089d0497c6ea2a5a613c9d2927c9 --- /dev/null +++ b/detect_tools/upn/models/decoder/__init__.py @@ -0,0 +1,3 @@ +from .upn_decoder import UPNDecoder, DeformableTransformerDecoderLayer + +__all__ = ["UPNDecoder", "DeformableTransformerDecoderLayer"] diff --git a/detect_tools/upn/models/decoder/upn_decoder.py b/detect_tools/upn/models/decoder/upn_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..117067ce4867c998ca77fe82169c228f5ecc3060 --- /dev/null +++ b/detect_tools/upn/models/decoder/upn_decoder.py @@ -0,0 +1,378 @@ +from typing import Dict + +import torch +import torch.nn as nn + +from detect_tools.upn import DECODERS, build_decoder +from detect_tools.upn.models.module import MLP +from detect_tools.upn.models.utils import (gen_sineembed_for_position, + get_activation_fn, get_clones, + inverse_sigmoid) +from detect_tools.upn.ops.modules import MSDeformAttn + + +@DECODERS.register_module() +class DeformableTransformerDecoderLayer(nn.Module): + """Deformable Transformer Decoder Layer. This is a modified version in Grounding DINO. + After the query is attented to the image feature, it is further attented to the text feature. + The execute order is: self_attn -> cross_attn to text -> cross_attn to image -> ffn + Args: + d_model (int): The dimension of keys/values/queries in :class:`MultiheadAttention`. + d_ffn (int): The dimension of the feedforward network model. + dropout (float): Probability of an element to be zeroed. + activation (str): Activation function in the feedforward network. + 'relu' and 'gelu' are supported. + n_levels (int): The number of levels in Multi-Scale Deformable Attention. + n_heads (int): Parallel attention heads. + n_points (int): Number of sampling points in Multi-Scale Deformable Attention. + ffn_extra_layernorm (bool): If True, add an extra layernorm after ffn. + """ + + def __init__( + self, + d_model: int = 256, + d_ffn: int = 1024, + dropout: float = 0.1, + activation: str = "relu", + n_levels: int = 4, + n_heads: int = 8, + n_points: int = 4, + ffn_extra_layernorm: bool = False, + ) -> None: + super().__init__() + + # cross attention for visual features + self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) + self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + self.norm1 = nn.LayerNorm(d_model) + + # self attention for query + self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) + self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + self.norm2 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = get_activation_fn(activation, d_model=d_ffn, batch_dim=1) + self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + self.norm3 = nn.LayerNorm(d_model) + if ffn_extra_layernorm: + raise NotImplementedError("ffn_extra_layernorm not implemented") + self.norm_ext = nn.LayerNorm(d_ffn) + else: + self.norm_ext = None + + self.key_aware_proj = None + + def rm_self_attn_modules(self): + self.self_attn = None + self.dropout2 = None + self.norm2 = None + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, tgt): + tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) + + tgt = tgt + self.dropout4(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward( + self, + tgt: torch.Tensor, + tgt_query_pos: torch.Tensor = None, + tgt_reference_points: torch.Tensor = None, + memory: torch.Tensor = None, + memory_key_padding_mask: torch.Tensor = None, + memory_level_start_index: torch.Tensor = None, + memory_spatial_shapes: torch.Tensor = None, + self_attn_mask: torch.Tensor = None, + cross_attn_mask: torch.Tensor = None, + ) -> torch.Tensor: + """Forward function + + Args: + tgt (torch.Tensor): Input target in shape (B, T, C) + tgt_query_pos (torch.Tensor): Positional encoding of the query. + tgt_query_sine_embed (torch.Tensor): Sine positional encoding of the query. Unused. + tgt_key_padding_mask (torch.Tensor): Mask for target feature in shape (B, T). + tgt_reference_points (torch.Tensor): Reference points for the query in shape (B, T, 4). + memory_text (torch.Tensor): Input text embeddings in shape (B, num_token, C). + text_attention_mask (torch.Tensor): Attention mask for text embeddings in shape + (B, num_token). + memory (torch.Tensor): Input image feature in shape (B, HW, C) + memory_key_padding_mask (torch.Tensor): Mask for image feature in shape (B, HW) + memory_level_start_index (torch.Tensor): Starting index of each level in memory. + memory_spatial_shapes (torch.Tensor): Spatial shape of each level in memory. + memory_pos (torch.Tensor): Positional encoding of memory. Unused. + self_attn_mask (torch.Tensor): Mask used for self-attention. + cross_attn_mask (torch.Tensor): Mask used for cross-attention. + + Returns: + torch.Tensor: Output tensor in shape (B, T, C) + """ + assert cross_attn_mask is None + + # self attention + if self.self_attn is not None: + q = k = self.with_pos_embed(tgt, tgt_query_pos) + tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # attend to image features + tgt2 = self.cross_attn( + self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1), + tgt_reference_points.transpose(0, 1).contiguous(), + memory.transpose(0, 1), + memory_spatial_shapes, + memory_level_start_index, + memory_key_padding_mask, + ).transpose(0, 1) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + # ffn + tgt = self.forward_ffn(tgt) + + return tgt + + +@DECODERS.register_module() +class UPNDecoder(nn.Module): + """Decoder used in UPN. Each layer is a DeformableTransformerDecoderLayer. The query + will be abled to attend the image feature and text feature. The execute order is: + self_attn -> cross_attn to image -> ffn + + Args: + decoder_layer_cfg (Dict): Config for the DeformableTransformerDecoderLayer. + num_layers (int): number of layers + norm (nn.Module, optional): normalization layer. Defaults to None. + return_intermediate (bool, optional): whether return intermediate results. + Defaults to False. + d_model (int, optional): dimension of the model. Defaults to 256. + query_dim (int, optional): dimension of the query. Defaults to 4. + modulate_hw_attn (bool, optional): whether modulate the attention weights + by the height and width of the image feature. Defaults to False. + num_feature_levels (int, optional): number of feature levels. Defaults to 1. + deformable_decoder (bool, optional): whether use deformable decoder. Defaults to False. + decoder_query_perturber ([type], optional): [description]. Defaults to None. + dec_layer_number ([type], optional): [description]. Defaults to None. + rm_dec_query_scale (bool, optional): [description]. Defaults to False. + dec_layer_share (bool, optional): [description]. Defaults to False. + dec_layer_dropout_prob ([type], optional): [description]. Defaults to None. + """ + + def __init__( + self, + decoder_layer_cfg: Dict, + num_layers: int, + norm: str = "layernorm", + return_intermediate: bool = True, + d_model: int = 256, + query_dim: int = 4, + modulate_hw_attn: bool = False, + num_feature_levels: int = 1, + deformable_decoder: bool = True, + decoder_query_perturber=None, + dec_layer_number=None, + rm_dec_query_scale: bool = True, + dec_layer_share: bool = False, + dec_layer_dropout_prob=None, + use_detached_boxes_dec_out: bool = False, + ): + super().__init__() + + decoder_layer = build_decoder(decoder_layer_cfg) + if num_layers > 0: + self.layers = get_clones( + decoder_layer, num_layers, layer_share=dec_layer_share + ) + else: + self.layers = [] + self.num_layers = num_layers + if norm == "layernorm": + self.norm = nn.LayerNorm(d_model) + self.return_intermediate = return_intermediate + self.query_dim = query_dim + assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim) + self.num_feature_levels = num_feature_levels + self.use_detached_boxes_dec_out = use_detached_boxes_dec_out + + self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2) + self.ref_point_head_point = MLP( + d_model, d_model, d_model, 2 + ) # for point reference only + if not deformable_decoder: + self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2) + else: + self.query_pos_sine_scale = None + + if rm_dec_query_scale: + self.query_scale = None + else: + raise NotImplementedError + self.query_scale = MLP(d_model, d_model, d_model, 2) + self.bbox_embed = None + self.class_embed = None + + self.d_model = d_model + self.modulate_hw_attn = modulate_hw_attn + self.deformable_decoder = deformable_decoder + + if not deformable_decoder and modulate_hw_attn: + self.ref_anchor_head = MLP(d_model, d_model, 2, 2) + else: + self.ref_anchor_head = None + + self.decoder_query_perturber = decoder_query_perturber + self.box_pred_damping = None + + self.dec_layer_number = dec_layer_number + if dec_layer_number is not None: + assert isinstance(dec_layer_number, list) + assert len(dec_layer_number) == num_layers + + self.dec_layer_dropout_prob = dec_layer_dropout_prob + if dec_layer_dropout_prob is not None: + assert isinstance(dec_layer_dropout_prob, list) + assert len(dec_layer_dropout_prob) == num_layers + for i in dec_layer_dropout_prob: + assert 0.0 <= i <= 1.0 + + self.rm_detach = None + + def forward( + self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: torch.Tensor = None, + memory_mask: torch.Tensor = None, + tgt_key_padding_mask: torch.Tensor = None, + memory_key_padding_mask: torch.Tensor = None, + pos: torch.Tensor = None, + refpoints_unsigmoid: torch.Tensor = None, + level_start_index: torch.Tensor = None, + spatial_shapes: torch.Tensor = None, + valid_ratios: torch.Tensor = None, + memory_ref_image: torch.Tensor = None, + refImg_padding_mask: torch.Tensor = None, + memory_visual_prompt: torch.Tensor = None, + ): + """Forward function. + + Args: + tgt (torch.Tensor): target feature, [bs, num_queries, d_model] + memory (torch.Tensor): Image feature, [bs, hw, d_model] + tgt_mask (torch.Tensor, optional): target mask for attention. Defaults to None. + memory_mask (torch.Tensor, optional): image mask for attention. Defaults to None. + tgt_key_padding_mask (torch.Tensor, optional): target mask for padding. Defaults to None. + memory_key_padding_mask (torch.Tensor, optional): image mask for padding. Defaults to None. + pos (torch.Tensor, optional): query position embedding + refpoints_unsigmoid (torch.Tensor, optional): reference points. Defaults to None. + level_start_index (torch.Tensor, optional): start index of each level. Defaults to None. + spatial_shapes (torch.Tensor, optional): spatial shape of each level. Defaults to None. + valid_ratios (torch.Tensor, optional): valid ratio of each level. Defaults to None. + memory_ref_image (torch.Tensor, optional): reference image feature, [bs, num_ref, d_model]. Defaults to None. + refImg_padding_mask (torch.Tensor, optional): padding mask for attention. Defaults to None. + """ + output = tgt + + intermediate = [] + reference_points = refpoints_unsigmoid.sigmoid() + ref_points = [reference_points] + + for layer_id, layer in enumerate(self.layers): + + if reference_points.shape[-1] == 4: + reference_points_input = ( + reference_points[:, :, None] + * torch.cat([valid_ratios, valid_ratios], -1)[None, :] + ) # nq, bs, nlevel, 4 + else: + assert reference_points.shape[-1] == 2 + reference_points_input = ( + reference_points[:, :, None] * valid_ratios[None, :] + ) + query_sine_embed = gen_sineembed_for_position( + reference_points_input[:, :, 0, :] + ) # nq, bs, 256*2 + + # conditional query + if query_sine_embed.shape[-1] == 512: + raw_query_pos = ( + self.ref_point_head(query_sine_embed) + + self.ref_point_head_point( + torch.zeros_like(query_sine_embed)[:, :, :256] + ) + * 0.0 + ) + else: + raw_query_pos = ( + self.ref_point_head_point(query_sine_embed) + + self.ref_point_head( + torch.zeros( + query_sine_embed.shape[0], + query_sine_embed.shape[1], + 512, + device=query_sine_embed.device, + ) + ) + * 0.0 + ) + pos_scale = self.query_scale(output) if self.query_scale is not None else 1 + query_pos = pos_scale * raw_query_pos + + # main process + output = layer( + tgt=output, + tgt_query_pos=query_pos, + tgt_reference_points=reference_points_input, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + memory_level_start_index=level_start_index, + memory_spatial_shapes=spatial_shapes, + self_attn_mask=tgt_mask, + cross_attn_mask=memory_mask, + ) + if output.isnan().any() | output.isinf().any(): + print(f"output layer_id {layer_id} is nan") + try: + num_nan = output.isnan().sum().item() + num_inf = output.isinf().sum().item() + print(f"num_nan {num_nan}, num_inf {num_inf}") + except Exception as e: + print(e) + + # iter update + if self.bbox_embed is not None: + + reference_before_sigmoid = inverse_sigmoid(reference_points) + delta_unsig = self.bbox_embed[layer_id](output) + outputs_unsig = delta_unsig + reference_before_sigmoid + new_reference_points = outputs_unsig.sigmoid() + + if self.rm_detach and "dec" in self.rm_detach: + reference_points = new_reference_points + else: + reference_points = new_reference_points.detach() + + if self.use_detached_boxes_dec_out: + ref_points.append(reference_points) + else: + ref_points.append(new_reference_points) + + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.return_intermediate: + return [ + [itm_out.transpose(0, 1) for itm_out in intermediate], + [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points], + ] + else: + return self.norm(output).transpose(0, 1) diff --git a/detect_tools/upn/models/encoder/__init__.py b/detect_tools/upn/models/encoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..907595e2c853dbd041014a82efa800612644f8d3 --- /dev/null +++ b/detect_tools/upn/models/encoder/__init__.py @@ -0,0 +1,3 @@ +from .upn_encoder import DeformableTransformerEncoderLayer, UPNEncoder + +__all__ = ["UPNEncoder", "DeformableTransformerEncoderLayer"] diff --git a/detect_tools/upn/models/encoder/upn_encoder.py b/detect_tools/upn/models/encoder/upn_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..ae338fd0e2d34cb48c556524552265974d2adf5c --- /dev/null +++ b/detect_tools/upn/models/encoder/upn_encoder.py @@ -0,0 +1,288 @@ +from typing import Dict + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from detect_tools.upn import ENCODERS, build_encoder +from detect_tools.upn.models.utils import get_activation_fn, get_clones +from detect_tools.upn.ops.modules import MSDeformAttn + + +@ENCODERS.register_module() +class DeformableTransformerEncoderLayer(nn.Module): + """Deformable Transformer Encoder Layer. + + Args: + d_model (int): The dimension of keys/values/queries in + :class:`MultiheadAttention`. + d_ffn (int): The dimension of the feedforward network model. + dropout (float): Probability of an element to be zeroed. + activation (str): Activation function in the feedforward network. + 'relu' and 'gelu' are supported. + n_levels (int): The number of levels in Multi-Scale Deformable Attention. + n_heads (int): Parallel attention heads. + n_points (int): Number of sampling points in Multi-Scale Deformable Attention. + add_channel_attention (bool): If True, add channel attention. + """ + + def __init__( + self, + d_model: int = 256, + d_ffn: int = 1024, + dropout: float = 0.1, + activation: str = "relu", + n_levels: int = 4, + n_heads: int = 8, + n_points: int = 4, + add_channel_attention: bool = False, + ) -> None: + super().__init__() + + # self attention + self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = get_activation_fn(activation, d_model=d_ffn) + self.dropout2 = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout3 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + # channel attention + self.add_channel_attention = add_channel_attention + if add_channel_attention: + self.activ_channel = get_activation_fn("dyrelu", d_model=d_model) + self.norm_channel = nn.LayerNorm(d_model) + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, src: torch.Tensor) -> torch.Tensor: + src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) + src = src + self.dropout3(src2) + src = self.norm2(src) + return src + + def forward( + self, + src: torch.Tensor, + pos: torch.Tensor, + reference_points: torch.Tensor, + spatial_shapes: torch.Tensor, + level_start_index: torch.Tensor, + key_padding_mask: torch.Tensor = None, + ) -> torch.Tensor: + """Forward function for `DeformableTransformerEncoderLayer`. + + Args: + src (torch.Tensor): The input sequence of shape (S, N, E). + pos (torch.Tensor): The position embedding of shape (S, N, E). + reference_points (torch.Tensor): The reference points of shape (N, L, 2). + spatial_shapes (torch.Tensor): The spatial shapes of feature levels. + level_start_index (torch.Tensor): The start index of each level. + key_padding_mask (torch.Tensor): The mask for keys with shape (N, S). + """ + # self attention + # import ipdb; ipdb.set_trace() + src2 = self.self_attn( + self.with_pos_embed(src, pos), + reference_points, + src, + spatial_shapes, + level_start_index, + key_padding_mask, + ) + src = src + self.dropout1(src2) + src = self.norm1(src) + + # ffn + src = self.forward_ffn(src) + + # channel attn + if self.add_channel_attention: + src = self.norm_channel(src + self.activ_channel(src)) + + return src + + +@ENCODERS.register_module() +class UPNEncoder(nn.Module): + """Implementation of UPN Encoder. + + Args: + num_layers (int): The number of layers in the TransformerEncoder. + d_model (int, optional): The dimension of the input feature. Defaults to 256. + encoder_layer_cfg (Dict): Config for the DeformableEncoderLayer. + use_checkpoint (bool, optional): Whether to use checkpoint in the fusion layer for + memory saving. Defaults to False. + use_transformer_ckpt (bool, optional): Whether to use checkpoint for the deformableencoder. + enc_layer_share (bool, optional): Whether to share the same memory for the encoder_layer. + Defaults to False. This is used for all the sub-layers in the basic block. + """ + + def __init__( + self, + num_layers: int, + d_model: int = 256, + encoder_layer_cfg: Dict = None, + use_checkpoint: bool = True, + use_transformer_ckpt: bool = True, + enc_layer_share: bool = False, + multi_level_encoder_fusion: str = None, + ): + super().__init__() + # prepare layers + self.layers = [] + self.refImg_layers = [] + self.fusion_layers = [] + encoder_layer = build_encoder(encoder_layer_cfg) + + self.multi_level_encoder_fusion = multi_level_encoder_fusion + self._initilize_memory_fusion_layers( + multi_level_encoder_fusion, num_layers, d_model + ) + + if num_layers > 0: + self.layers = get_clones( + encoder_layer, num_layers, layer_share=enc_layer_share + ) + else: + self.layers = [] + del encoder_layer + + self.query_scale = None + self.num_layers = num_layers + self.d_model = d_model + + self.use_checkpoint = use_checkpoint + self.use_transformer_ckpt = use_transformer_ckpt + + def _initilize_memory_fusion_layers(self, fusion_type, num_layers, d_model): + if fusion_type is None: + self.memory_fusion_layer = None + return + + assert fusion_type in ["dense_net_fusion", "stable_dense_fusion"] + if fusion_type == "stable_dense_fusion": + self.memory_fusion_layer = nn.Sequential( + nn.Linear(d_model * (num_layers + 1), d_model), + nn.LayerNorm(d_model), + ) + nn.init.constant_(self.memory_fusion_layer[0].bias, 0) + elif fusion_type == "dense_net_fusion": + self.memory_fusion_layer = nn.ModuleList() + for i in range(num_layers): + self.memory_fusion_layer.append( + nn.Sequential( + nn.Linear( + d_model * (i + 2), d_model + ), # from second encoder layer, 512 -> 256 / 3rd: 768 -> 256 + nn.LayerNorm(d_model), + ) + ) + for layer in self.memory_fusion_layer: + nn.init.constant_(layer[0].bias, 0) + else: + raise NotImplementedError + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, device): + reference_points_list = [] + for lvl, (H_, W_) in enumerate(spatial_shapes): + + ref_y, ref_x = torch.meshgrid( + torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), + torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device), + ) + ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) + ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + def forward( + self, + src: torch.Tensor, + pos: torch.Tensor, + spatial_shapes: torch.Tensor, + level_start_index: torch.Tensor, + valid_ratios: torch.Tensor, + key_padding_mask: torch.Tensor = None, + ): + """Forward function + + Args: + src (torch.Tensor): Flattened Image features in shape [bs, sum(hi*wi), 256] + pos (torch.Tensor): Position embedding for image feature in shape [bs, sum(hi*wi), 256] + spatial_shapes (torch.Tensor): Spatial shape of each level in shape [num_level, 2] + level_start_index (torch.Tensor): Start index of each level in shape [num_level] + valid_ratios (torch.Tensor): Valid ratio of each level in shape [bs, num_level, 2] + key_padding_mask (torch.Tensor): Padding mask for image feature in shape [bs, sum(hi*wi)] + memory_refImg (torch.Tensor, optional): Text feature in shape [bs, n_ref, 256]. Defaults + to None. + refImg_padding_mask (torch.Tensor, optional): Padding mask for reference image feature + in shape [bs, n_text]. Defaults to None. + pos_refImg (torch.Tensor, optional): Position embedding for reference image in shape + [bs, n_ref, 256]. Defaults to None. + refImg_self_attention_masks (torch.Tensor, optional): Self attention mask for reference + image feature in shape [bs, n_ref, n_ref]. Defaults to None. + Outpus: + torch.Tensor: Encoded image feature in shape [bs, sum(hi*wi), 256] + torch.Tensor: Encoded reference image feature in shape [bs, n_ref, 256] + """ + + output = src + # preparation and reshape + if self.num_layers > 0: + reference_points = self.get_reference_points( + spatial_shapes, valid_ratios, device=src.device + ) + + # multi-level dense fusion + output_list = [output] + # main process + for layer_id, layer in enumerate(self.layers): + # main process + if self.use_transformer_ckpt: + output = checkpoint.checkpoint( + layer, + output, + pos, + reference_points, + spatial_shapes, + level_start_index, + key_padding_mask, + ) + else: + output = layer( + src=output, + pos=pos, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + key_padding_mask=key_padding_mask, + ) + + output_list.append(output) + if ( + self.multi_level_encoder_fusion is not None + and self.multi_level_encoder_fusion == "dense_net_fusion" + ): + output = self.memory_fusion_layer[layer_id]( + torch.cat(output_list, dim=-1) + ) + + if ( + self.multi_level_encoder_fusion is not None + and self.multi_level_encoder_fusion == "stable_dense_fusion" + ): + output = self.memory_fusion_layer(torch.cat(output_list, dim=-1)) + + return output diff --git a/detect_tools/upn/models/module/__init__.py b/detect_tools/upn/models/module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03a4509a5f64f33feb9dfc9245e65a2a67e635a8 --- /dev/null +++ b/detect_tools/upn/models/module/__init__.py @@ -0,0 +1,5 @@ +from .contrastive import ContrastiveAssign +from .mlp import MLP +from .nested_tensor import NestedTensor, nested_tensor_from_tensor_list + +__all__ = ["MLP", "NestedTensor", "nested_tensor_from_tensor_list", "ContrastiveAssign"] diff --git a/detect_tools/upn/models/module/contrastive.py b/detect_tools/upn/models/module/contrastive.py new file mode 100644 index 0000000000000000000000000000000000000000..7830f353738d2fd10c23e386cd2a42188b0373e4 --- /dev/null +++ b/detect_tools/upn/models/module/contrastive.py @@ -0,0 +1,29 @@ +from typing import Dict + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ContrastiveAssign(nn.Module): + + def __init__( + self, + cal_bias: nn.Module = None, + ) -> None: + """Lanuage-Image Contrastive Assignment used to calculate the similarity between + the text and the image. + + Args: + cal_bias (nn.Module, optional): The bias used to calculate the similarity. + Defaults to None. + max_text_len (int, optional): The max length of the text. Defaults to 256. + """ + super().__init__() + self.cal_bias = cal_bias + + def forward(self, x: torch.Tensor, ref_dict: Dict): + + y = ref_dict["encoded_ref_feature"] + res = x @ y.transpose(-1, -2) + return res diff --git a/detect_tools/upn/models/module/mlp.py b/detect_tools/upn/models/module/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..4b654715a33abc42df5b29aaa52cfeac2664941c --- /dev/null +++ b/detect_tools/upn/models/module/mlp.py @@ -0,0 +1,18 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x diff --git a/detect_tools/upn/models/module/nested_tensor.py b/detect_tools/upn/models/module/nested_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..8655ebb23fae2003141e0933e3952ec979edeb65 --- /dev/null +++ b/detect_tools/upn/models/module/nested_tensor.py @@ -0,0 +1,199 @@ +from typing import List, Union + +import torch +import torchvision + + +class NestedTensor(object): + """Define a NestedTensor class + + Args: + tensors (torch.Tensor): Tensor with shape [batch, C, H, W] or [C, H, W] + mask (Union[torch.Tensor, str]): mask with shape [batch, H, W] or [H, W]. If mask + is 'auto', it will be generated automatically by summing the tensor along + the channel dimension. Mask is used to indicate the padding area. + """ + + def __init__( + self, tensors: torch.Tensor, mask: Union[torch.Tensor, str] = "auto" + ) -> None: + self.tensors = tensors + self.mask = mask + if mask == "auto": + self.mask = torch.zeros_like(tensors).to(tensors.device) + if self.mask.dim() == 3: + self.mask = self.mask.sum(0).to(bool) + elif self.mask.dim() == 4: + self.mask = self.mask.sum(1).to(bool) + else: + raise ValueError( + "tensors dim must be 3 or 4 but {}({})".format( + self.tensors.dim(), self.tensors.shape + ) + ) + + def imgsize(self) -> List[torch.Tensor]: + """get the img size of the tensor + + Returns: + list[torch.Tensor]: list of tensor with shape [2] which is [H, W] + """ + res = [] + for i in range(self.tensors.shape[0]): + mask = self.mask[i] + maxH = (~mask).sum(0).max() + maxW = (~mask).sum(1).max() + res.append(torch.Tensor([maxH, maxW])) + return res + + def to(self, device: torch.device): + """Move tensors and mask to the given device + + Args: + device (torch.device): device to move + + Returns: + NestedTensor: moved NestedTensor + """ + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def to_img_list_single( + self, tensor: torch.Tensor, mask: torch.Tensor + ) -> torch.Tensor: + """remove the padding for one image + + Args: + tensor (torch.Tensor): tensor with shape [C, H, W] + mask (torch.Tensor): mask with shape [H, W] + + Returns: + torch.Tensor: tensor with shape [C, maxH, maxW] + """ + assert tensor.dim() == 3, "dim of tensor should be 3 but {}".format( + tensor.dim() + ) + maxH = (~mask).sum(0).max() + maxW = (~mask).sum(1).max() + img = tensor[:, :maxH, :maxW] + return img + + def to_img_list(self) -> List[torch.Tensor]: + """remove the padding and convert to img list + + Returns: + list[torch.Tensor]: list of tensor with shape [C, maxH, maxW] + """ + if self.tensors.dim() == 3: + return self.to_img_list_single(self.tensors, self.mask) + else: + res = [] + for i in range(self.tensors.shape[0]): + tensor_i = self.tensors[i] + mask_i = self.mask[i] + res.append(self.to_img_list_single(tensor_i, mask_i)) + return res + + @property + def device(self): + return self.tensors.device + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + @property + def shape(self): + return {"tensors.shape": self.tensors.shape, "mask.shape": self.mask.shape} + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +def nested_tensor_from_tensor_list( + tensor_list: List[torch.Tensor], fixed_img_size=None +): + if fixed_img_size is not None: + if isinstance(fixed_img_size, (list, tuple)): + assert ( + len(fixed_img_size) == 2 + ), "image size should be a tuple or list with two elements" + elif isinstance(fixed_img_size, int): + fixed_img_size = [fixed_img_size, fixed_img_size] + + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + + if fixed_img_size is not None: + c, orig_h, orig_w = max_size + assert ( + orig_h <= fixed_img_size[0] and orig_w <= fixed_img_size[1] + ), f"{orig_h} {orig_w} the fixed output image size should be larger than original image" + max_size = [c, fixed_img_size[0], fixed_img_size[1]] + + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], : img.shape[2]] = False + else: + raise ValueError("not supported") + return NestedTensor(tensor, mask) + + +@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list( + tensor_list: List[torch.Tensor], +) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max( + torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32) + ).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad( + img, (0, padding[2], 0, padding[1], 0, padding[0]) + ) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad( + m, (0, padding[2], 0, padding[1]), "constant", 1 + ) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) diff --git a/detect_tools/upn/models/utils/__init__.py b/detect_tools/upn/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..95de7a1bb95bfb5807691520e8a897cd3ac8d83e --- /dev/null +++ b/detect_tools/upn/models/utils/__init__.py @@ -0,0 +1,23 @@ +from .detr_utils import ( + PositionEmbeddingLearned, + PositionEmbeddingSine, + PositionEmbeddingSineHW, + clean_state_dict, + gen_encoder_output_proposals, + gen_sineembed_for_position, + get_activation_fn, + get_clones, + inverse_sigmoid, +) + +__all__ = [ + "inverse_sigmoid", + "gen_encoder_output_proposals", + "get_clones", + "gen_sineembed_for_position", + "get_activation_fn", + "clean_state_dict", + "PositionEmbeddingSine", + "PositionEmbeddingSineHW", + "PositionEmbeddingLearned", +] diff --git a/detect_tools/upn/models/utils/detr_utils.py b/detect_tools/upn/models/utils/detr_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b4af48d3f042e1a10ff2484ecd8069f16e5a6404 --- /dev/null +++ b/detect_tools/upn/models/utils/detr_utils.py @@ -0,0 +1,415 @@ +import copy +import math +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import nn + +from detect_tools.upn import POS_EMBEDDINGS +from detect_tools.upn.models.module import NestedTensor + + +@POS_EMBEDDINGS.register_module() +class PositionEmbeddingSine(nn.Module): + """This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + + Args: + num_pos_feats (int): The channel of positional embeddings. + temperature (float): The temperature used in positional embeddings. + normalize (bool): Whether to normalize the positional embeddings. + scale (float): The scale factor of positional embeddings. + """ + + def __init__( + self, + num_pos_feats: int = 64, + temperature: int = 10000, + normalize: bool = False, + scale: float = None, + ) -> None: + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor_list: NestedTensor) -> torch.Tensor: + """Forward function. + + Args: + tensor_list (NestedTensor): NestedTensor wrapping the input tensor. + + Returns: + torch.Tensor: Positional encoding in shape (B, num_pos_feats*2, H, W) + """ + x = tensor_list.tensors + mask = tensor_list.mask + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +@POS_EMBEDDINGS.register_module() +class PositionEmbeddingSineHW(nn.Module): + """This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + + Args: + num_pos_feats (int): The channel of positional embeddings. + temperatureH (float): The temperature used in positional embeddings. + temperatureW (float): The temperature used in positional embeddings. + normalize (bool): Whether to normalize the positional embeddings. + scale (float): The scale factor of positional embeddings. + """ + + def __init__( + self, + num_pos_feats: int = 64, + temperatureH: int = 10000, + temperatureW: int = 10000, + normalize: bool = False, + scale: float = None, + ) -> None: + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperatureH = temperatureH + self.temperatureW = temperatureW + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor_list: NestedTensor) -> torch.Tensor: + """Forward function. + + Args: + tensor_list (NestedTensor): NestedTensor wrapping the input tensor. + + Returns: + torch.Tensor: Positional encoding in shape (B, num_pos_feats*2, H, W) + """ + x = tensor_list.tensors + mask = tensor_list.mask + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_tx = self.temperatureW ** (2 * (dim_tx // 2) / self.num_pos_feats) + pos_x = x_embed[:, :, :, None] / dim_tx + + dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_ty = self.temperatureH ** (2 * (dim_ty // 2) / self.num_pos_feats) + pos_y = y_embed[:, :, :, None] / dim_ty + + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + + return pos + + +@POS_EMBEDDINGS.register_module() +class PositionEmbeddingLearned(nn.Module): + """Absolute pos embedding, learned. + + Args: + num_pos_feats (int): The channel dimension of positional embeddings. + num_row (int): The number of rows of the input feature map. + num_col (int): The number of columns of the input feature map. + """ + + def __init__( + self, num_row: int = 50, num_col: int = 50, num_pos_feats: int = 256 + ) -> None: + super().__init__() + self.row_embed = nn.Embedding(num_row, num_pos_feats) + self.col_embed = nn.Embedding(num_col, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensor_list: NestedTensor) -> torch.Tensor: + """Forward function. + + Args: + tensor_list (NestedTensor): NestedTensor wrapping the input tensor. + + Returns: + torch.Tensor: Positional encoding in shape (B, num_pos_feats*2, H, W) + """ + x = tensor_list.tensors + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = ( + torch.cat( + [ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], + dim=-1, + ) + .permute(2, 0, 1) + .unsqueeze(0) + .repeat(x.shape[0], 1, 1, 1) + ) + return pos + + +def build_position_encoding(args): + N_steps = args.hidden_dim // 2 + if args.position_embedding in ("v2", "sine"): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSineHW( + N_steps, + temperatureH=args.pe_temperatureH, + temperatureW=args.pe_temperatureW, + normalize=True, + ) + elif args.position_embedding in ("v3", "learned"): + position_embedding = PositionEmbeddingLearned(N_steps) + else: + raise ValueError(f"not supported {args.position_embedding}") + + return position_embedding + + +def clean_state_dict(state_dict): + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if k[:7] == "module.": + k = k[7:] # remove `module.` + new_state_dict[k] = v + return new_state_dict + + +def get_activation_fn(activation: str, d_model: int = 256, batch_dim: int = 0): + """Return an activation function given a string + + Args: + activation (str): activation function name + d_model (int, optional): d_model. Defaults to 256. + batch_dim (int, optional): batch dimension. Defaults to 0. + + Returns: + F: activation function + """ + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + if activation == "prelu": + return nn.PReLU() + if activation == "selu": + return F.selu + + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") + + +def get_clones(module: nn.Module, N: int, layer_share: bool = False): + """Copy module N times + + Args: + module (nn.Module): module to copy + N (int): number of copies + layer_share (bool, optional): share the same layer. If true, the modules will + share the same memory. Defaults to False. + """ + if layer_share: + return nn.ModuleList([module for _ in range(N)]) + else: + return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) + + +def inverse_sigmoid(x, eps=1e-3): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +def gen_sineembed_for_position(pos_tensor): + # n_query, bs, _ = pos_tensor.size() + # sineembed_tensor = torch.zeros(n_query, bs, 256) + scale = 2 * math.pi + dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device) + dim_t = 10000 ** (2 * (dim_t // 2) / 128) + x_embed = pos_tensor[:, :, 0] * scale + y_embed = pos_tensor[:, :, 1] * scale + pos_x = x_embed[:, :, None] / dim_t + pos_y = y_embed[:, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3 + ).flatten(2) + pos_y = torch.stack( + (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3 + ).flatten(2) + if pos_tensor.size(-1) == 2: + pos = torch.cat((pos_y, pos_x), dim=2) + elif pos_tensor.size(-1) == 4: + w_embed = pos_tensor[:, :, 2] * scale + pos_w = w_embed[:, :, None] / dim_t + pos_w = torch.stack( + (pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3 + ).flatten(2) + + h_embed = pos_tensor[:, :, 3] * scale + pos_h = h_embed[:, :, None] / dim_t + pos_h = torch.stack( + (pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3 + ).flatten(2) + + pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) + else: + raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1))) + return pos + + +def get_sine_pos_embed( + pos_tensor: torch.Tensor, + num_pos_feats: int = 128, + temperature: int = 10000, + exchange_xy: bool = True, +): + """generate sine position embedding from a position tensor + Args: + pos_tensor (torch.Tensor): shape: [..., n]. + num_pos_feats (int): projected shape for each float in the tensor. + temperature (int): temperature in the sine/cosine function. + exchange_xy (bool, optional): exchange pos x and pos y. \ + For example, input tensor is [x,y], the results will be [pos(y), pos(x)]. Defaults to True. + Returns: + pos_embed (torch.Tensor): shape: [..., n*num_pos_feats]. + """ + scale = 2 * math.pi + dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device) + dim_t = temperature ** ( + 2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats + ) + + def sine_func(x: torch.Tensor): + sin_x = x * scale / dim_t + sin_x = torch.stack( + (sin_x[..., 0::2].sin(), sin_x[..., 1::2].cos()), dim=3 + ).flatten(2) + return sin_x + + pos_res = [ + sine_func(x) for x in pos_tensor.split([1] * pos_tensor.shape[-1], dim=-1) + ] + if exchange_xy: + pos_res[0], pos_res[1] = pos_res[1], pos_res[0] + pos_res = torch.cat(pos_res, dim=-1) + return pos_res + + +def gen_encoder_output_proposals( + memory: torch.Tensor, + memory_padding_mask: torch.Tensor, + spatial_shapes: torch.Tensor, + learnedwh=None, +): + """ + Input: + - memory: bs, \sum{hw}, d_model + - memory_padding_mask: bs, \sum{hw} + - spatial_shapes: nlevel, 2 + - learnedwh: 2 + Output: + - output_memory: bs, \sum{hw}, d_model + - output_proposals: bs, \sum{hw}, 4 + """ + N_, S_, C_ = memory.shape + base_scale = 4.0 + proposals = [] + _cur = 0 + for lvl, (H_, W_) in enumerate(spatial_shapes): + mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view( + N_, H_, W_, 1 + ) + valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) + valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) + grid_y, grid_x = torch.meshgrid( + torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device), + torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device), + ) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2 + + scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view( + N_, 1, 1, 2 + ) + grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale + + if learnedwh is not None: + wh = torch.ones_like(grid) * learnedwh.sigmoid() * (2.0**lvl) + else: + wh = torch.ones_like(grid) * 0.05 * (2.0**lvl) + + proposal = torch.cat((grid, wh), -1).view(N_, -1, 4) + proposals.append(proposal) + _cur += H_ * W_ + + output_proposals = torch.cat(proposals, 1) + output_proposals_valid = ( + (output_proposals > 0.01) & (output_proposals < 0.99) + ).all(-1, keepdim=True) + output_proposals = torch.log(output_proposals / (1 - output_proposals)) # unsigmoid + output_proposals = output_proposals.masked_fill( + memory_padding_mask.unsqueeze(-1), float("inf") + ) + output_proposals = output_proposals.masked_fill( + ~output_proposals_valid, float("inf") + ) + + output_memory = memory + output_memory = output_memory.masked_fill( + memory_padding_mask.unsqueeze(-1), float(0) + ) + output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) + + return output_memory, output_proposals diff --git a/detect_tools/upn/ops/functions/__init__.py b/detect_tools/upn/ops/functions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8a2197bda3199aa32cafc5b9d396479609853dd2 --- /dev/null +++ b/detect_tools/upn/ops/functions/__init__.py @@ -0,0 +1,10 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from .ms_deform_attn_func import MSDeformAttnFunction + diff --git a/detect_tools/upn/ops/functions/ms_deform_attn_func.py b/detect_tools/upn/ops/functions/ms_deform_attn_func.py new file mode 100644 index 0000000000000000000000000000000000000000..8c5df8cf5d23aca963eec6c1133c180b37289607 --- /dev/null +++ b/detect_tools/upn/ops/functions/ms_deform_attn_func.py @@ -0,0 +1,61 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +import torch.nn.functional as F +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +import MultiScaleDeformableAttention as MSDA + + +class MSDeformAttnFunction(Function): + @staticmethod + def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): + ctx.im2col_step = im2col_step + output = MSDA.ms_deform_attn_forward( + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) + ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors + grad_value, grad_sampling_loc, grad_attn_weight = \ + MSDA.ms_deform_attn_backward( + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) + + return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None + + +def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): + # for debug and test only, + # need to use cuda version instead + N_, S_, M_, D_ = value.shape + _, Lq_, M_, L_, P_, _ = sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for lid_, (H_, W_) in enumerate(value_spatial_shapes): + # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ + value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) + # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 + sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) + # N_*M_, D_, Lq_, P_ + sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, + mode='bilinear', padding_mode='zeros', align_corners=False) + sampling_value_list.append(sampling_value_l_) + # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) + attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) + output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) + return output.transpose(1, 2).contiguous() diff --git a/detect_tools/upn/ops/modules/__init__.py b/detect_tools/upn/ops/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f82cb1ad9d634a87b54ba6a71b58a230bcade5fe --- /dev/null +++ b/detect_tools/upn/ops/modules/__init__.py @@ -0,0 +1,9 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from .ms_deform_attn import MSDeformAttn diff --git a/detect_tools/upn/ops/modules/ms_deform_attn.py b/detect_tools/upn/ops/modules/ms_deform_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..c1f11139a25c2a8f59076088a6042bac03773941 --- /dev/null +++ b/detect_tools/upn/ops/modules/ms_deform_attn.py @@ -0,0 +1,204 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import warnings +import math, os + +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.init import xavier_uniform_, constant_ + +try: + from ..functions import MSDeformAttnFunction +except: + warnings.warn("Failed to import MSDeformAttnFunction.") + + +def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError( + "invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)) + ) + return (n & (n - 1) == 0) and n != 0 + + +class MSDeformAttn(nn.Module): + def __init__( + self, d_model=256, n_levels=4, n_heads=8, n_points=4, use_4D_normalizer=False + ): + """ + Multi-Scale Deformable Attention Module + :param d_model hidden dimension + :param n_levels number of feature levels + :param n_heads number of attention heads + :param n_points number of sampling points per attention head per feature level + """ + super().__init__() + if d_model % n_heads != 0: + raise ValueError( + "d_model must be divisible by n_heads, but got {} and {}".format( + d_model, n_heads + ) + ) + _d_per_head = d_model // n_heads + # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation + if not _is_power_of_2(_d_per_head): + warnings.warn( + "You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " + "which is more efficient in our CUDA implementation." + ) + + self.im2col_step = 64 + + self.d_model = d_model + self.n_levels = n_levels + self.n_heads = n_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) + self.value_proj = nn.Linear(d_model, d_model) + self.output_proj = nn.Linear(d_model, d_model) + + self.use_4D_normalizer = use_4D_normalizer + + self._reset_parameters() + + def _reset_parameters(self): + constant_(self.sampling_offsets.weight.data, 0.0) + thetas = torch.arange(self.n_heads, dtype=torch.float32) * ( + 2.0 * math.pi / self.n_heads + ) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(self.n_heads, 1, 1, 2) + .repeat(1, self.n_levels, self.n_points, 1) + ) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.0) + constant_(self.attention_weights.bias.data, 0.0) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.0) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.0) + + @torch.cuda.amp.autocast(enabled=False) + def forward( + self, + query, + reference_points, + input_flatten, + input_spatial_shapes, + input_level_start_index, + input_padding_mask=None, + ): + """ + :param query (N, Length_{query}, C) + :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area + or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes + :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) + :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements + + :return output (N, Length_{query}, C) + """ + N, Len_q, _ = query.shape + N, Len_in, _ = input_flatten.shape + assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in + + value = self.value_proj(input_flatten) + if input_padding_mask is not None: + value = value.masked_fill(input_padding_mask[..., None], float(0)) + value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(query).view( + N, Len_q, self.n_heads, self.n_levels, self.n_points, 2 + ) + attention_weights = self.attention_weights(query).view( + N, Len_q, self.n_heads, self.n_levels * self.n_points + ) + attention_weights = F.softmax(attention_weights, -1).view( + N, Len_q, self.n_heads, self.n_levels, self.n_points + ) + # N, Len_q, n_heads, n_levels, n_points, 2 + + # if os.environ.get('IPDB_DEBUG_SHILONG', False) == 'INFO': + # import ipdb; ipdb.set_trace() + + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack( + [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1 + ) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif reference_points.shape[-1] == 4: + if self.use_4D_normalizer: + offset_normalizer = torch.stack( + [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1 + ) + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + + sampling_offsets + / offset_normalizer[None, None, None, :, None, :] + * reference_points[:, :, None, :, None, 2:] + * 0.5 + ) + else: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + + sampling_offsets + / self.n_points + * reference_points[:, :, None, :, None, 2:] + * 0.5 + ) + else: + raise ValueError( + "Last dim of reference_points must be 2 or 4, but get {} instead.".format( + reference_points.shape[-1] + ) + ) + + # if os.environ.get('IPDB_DEBUG_SHILONG', False) == 'INFO': + # import ipdb; ipdb.set_trace() + + # for amp + if value.dtype == torch.float16: + # for mixed precision + output = MSDeformAttnFunction.apply( + value.to(torch.float32), + input_spatial_shapes, + input_level_start_index, + sampling_locations.to(torch.float32), + attention_weights, + self.im2col_step, + ) + output = output.to(torch.float16) + output = self.output_proj(output) + return output + + output = MSDeformAttnFunction.apply( + value, + input_spatial_shapes, + input_level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + output = self.output_proj(output) + return output diff --git a/detect_tools/upn/ops/modules/ms_deform_attn_key_aware.py b/detect_tools/upn/ops/modules/ms_deform_attn_key_aware.py new file mode 100644 index 0000000000000000000000000000000000000000..03c49029f1ed963a22a7e6d18b5eaa916dea0f36 --- /dev/null +++ b/detect_tools/upn/ops/modules/ms_deform_attn_key_aware.py @@ -0,0 +1,130 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import warnings +import math, os + +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.init import xavier_uniform_, constant_ + +try: + from ..functions import MSDeformAttnFunction +except: + warnings.warn('Failed to import MSDeformAttnFunction.') + + +def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) + return (n & (n-1) == 0) and n != 0 + + +class MSDeformAttn(nn.Module): + def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, use_4D_normalizer=False): + """ + Multi-Scale Deformable Attention Module + :param d_model hidden dimension + :param n_levels number of feature levels + :param n_heads number of attention heads + :param n_points number of sampling points per attention head per feature level + """ + super().__init__() + if d_model % n_heads != 0: + raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) + _d_per_head = d_model // n_heads + # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation + if not _is_power_of_2(_d_per_head): + warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " + "which is more efficient in our CUDA implementation.") + + self.im2col_step = 64 + + self.d_model = d_model + self.n_levels = n_levels + self.n_heads = n_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) + self.value_proj = nn.Linear(d_model, d_model) + self.output_proj = nn.Linear(d_model, d_model) + + self.use_4D_normalizer = use_4D_normalizer + + self._reset_parameters() + + def _reset_parameters(self): + constant_(self.sampling_offsets.weight.data, 0.) + thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.) + constant_(self.attention_weights.bias.data, 0.) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.) + + def forward(self, query, key, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): + """ + :param query (N, Length_{query}, C) + :param key (N, 1, C) + :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area + or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes + :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) + :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements + + :return output (N, Length_{query}, C) + """ + N, Len_q, _ = query.shape + N, Len_in, _ = input_flatten.shape + assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in + + value = self.value_proj(input_flatten) + if input_padding_mask is not None: + value = value.masked_fill(input_padding_mask[..., None], float(0)) + value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) + attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) + attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) + # N, Len_q, n_heads, n_levels, n_points, 2 + + # if os.environ.get('IPDB_DEBUG_SHILONG', False) == 'INFO': + # import ipdb; ipdb.set_trace() + + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) + sampling_locations = reference_points[:, :, None, :, None, :] \ + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + elif reference_points.shape[-1] == 4: + if self.use_4D_normalizer: + offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) + sampling_locations = reference_points[:, :, None, :, None, :2] \ + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] * reference_points[:, :, None, :, None, 2:] * 0.5 + else: + sampling_locations = reference_points[:, :, None, :, None, :2] \ + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + else: + raise ValueError( + 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) + output = MSDeformAttnFunction.apply( + value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) + output = self.output_proj(output) + return output diff --git a/detect_tools/upn/ops/setup.py b/detect_tools/upn/ops/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..049f9232ba48996518ccd76df4cf39c8ed14791e --- /dev/null +++ b/detect_tools/upn/ops/setup.py @@ -0,0 +1,73 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +import os +import glob + +import torch + +from torch.utils.cpp_extension import CUDA_HOME +from torch.utils.cpp_extension import CppExtension +from torch.utils.cpp_extension import CUDAExtension + +from setuptools import find_packages +from setuptools import setup + +requirements = ["torch", "torchvision"] + +def get_extensions(): + this_dir = os.path.dirname(os.path.abspath(__file__)) + extensions_dir = os.path.join(this_dir, "src") + + main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) + source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) + + sources = main_file + source_cpu + extension = CppExtension + extra_compile_args = {"cxx": []} + define_macros = [] + + # import ipdb; ipdb.set_trace() + + if torch.cuda.is_available() and CUDA_HOME is not None: + extension = CUDAExtension + sources += source_cuda + define_macros += [("WITH_CUDA", None)] + extra_compile_args["nvcc"] = [ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + ] + else: + raise NotImplementedError('Cuda is not availabel') + + sources = [os.path.join(extensions_dir, s) for s in sources] + include_dirs = [extensions_dir] + ext_modules = [ + extension( + "MultiScaleDeformableAttention", + sources, + include_dirs=include_dirs, + define_macros=define_macros, + extra_compile_args=extra_compile_args, + ) + ] + return ext_modules + +setup( + name="MultiScaleDeformableAttention", + version="1.0", + author="Weijie Su", + url="https://github.com/fundamentalvision/Deformable-DETR", + description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", + packages=find_packages(exclude=("configs", "tests",)), + ext_modules=get_extensions(), + cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, +) diff --git a/detect_tools/upn/ops/src/cpu/ms_deform_attn_cpu.cpp b/detect_tools/upn/ops/src/cpu/ms_deform_attn_cpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e1bf854de1f3860d20b6fef5c1a17817c268e70a --- /dev/null +++ b/detect_tools/upn/ops/src/cpu/ms_deform_attn_cpu.cpp @@ -0,0 +1,41 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include + +#include +#include + + +at::Tensor +ms_deform_attn_cpu_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + AT_ERROR("Not implement on cpu"); +} + +std::vector +ms_deform_attn_cpu_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + AT_ERROR("Not implement on cpu"); +} + diff --git a/detect_tools/upn/ops/src/cpu/ms_deform_attn_cpu.h b/detect_tools/upn/ops/src/cpu/ms_deform_attn_cpu.h new file mode 100644 index 0000000000000000000000000000000000000000..81b7b58a3d9502bbb684dc84687a526dedf94cae --- /dev/null +++ b/detect_tools/upn/ops/src/cpu/ms_deform_attn_cpu.h @@ -0,0 +1,33 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once +#include + +at::Tensor +ms_deform_attn_cpu_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step); + +std::vector +ms_deform_attn_cpu_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step); + + diff --git a/detect_tools/upn/ops/src/cuda/ms_deform_attn_cuda.cu b/detect_tools/upn/ops/src/cuda/ms_deform_attn_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..059e2c44410b38f165c0d255d6bc50d4eb7af6e6 --- /dev/null +++ b/detect_tools/upn/ops/src/cuda/ms_deform_attn_cuda.cu @@ -0,0 +1,153 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include +#include "cuda/ms_deform_im2col_cuda.cuh" + +#include +#include +#include +#include + + +at::Tensor ms_deform_attn_cuda_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); + + AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + + auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); + + const int batch_n = im2col_step_; + auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + for (int n = 0; n < batch/im2col_step_; ++n) + { + auto columns = output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] { + ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), + value.data_ptr() + n * im2col_step_ * per_value_size, + spatial_shapes.data_ptr(), + level_start_index.data_ptr(), + sampling_loc.data_ptr() + n * im2col_step_ * per_sample_loc_size, + attn_weight.data_ptr() + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, + columns.data_ptr()); + + })); + } + + output = output.view({batch, num_query, num_heads*channels}); + + return output; +} + + +std::vector ms_deform_attn_cuda_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); + AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); + + AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor"); + AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + + auto grad_value = at::zeros_like(value); + auto grad_sampling_loc = at::zeros_like(sampling_loc); + auto grad_attn_weight = at::zeros_like(attn_weight); + + const int batch_n = im2col_step_; + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); + + for (int n = 0; n < batch/im2col_step_; ++n) + { + auto grad_output_g = grad_output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] { + ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), + grad_output_g.data_ptr(), + value.data_ptr() + n * im2col_step_ * per_value_size, + spatial_shapes.data_ptr(), + level_start_index.data_ptr(), + sampling_loc.data_ptr() + n * im2col_step_ * per_sample_loc_size, + attn_weight.data_ptr() + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, + grad_value.data_ptr() + n * im2col_step_ * per_value_size, + grad_sampling_loc.data_ptr() + n * im2col_step_ * per_sample_loc_size, + grad_attn_weight.data_ptr() + n * im2col_step_ * per_attn_weight_size); + + })); + } + + return { + grad_value, grad_sampling_loc, grad_attn_weight + }; +} \ No newline at end of file diff --git a/detect_tools/upn/ops/src/cuda/ms_deform_attn_cuda.h b/detect_tools/upn/ops/src/cuda/ms_deform_attn_cuda.h new file mode 100644 index 0000000000000000000000000000000000000000..c7ae53f99c820ce6193b608ad344550348a0b42c --- /dev/null +++ b/detect_tools/upn/ops/src/cuda/ms_deform_attn_cuda.h @@ -0,0 +1,30 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once +#include + +at::Tensor ms_deform_attn_cuda_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step); + +std::vector ms_deform_attn_cuda_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step); + diff --git a/detect_tools/upn/ops/src/cuda/ms_deform_im2col_cuda.cuh b/detect_tools/upn/ops/src/cuda/ms_deform_im2col_cuda.cuh new file mode 100644 index 0000000000000000000000000000000000000000..6bc2acb7aea0eab2e9e91e769a16861e1652c284 --- /dev/null +++ b/detect_tools/upn/ops/src/cuda/ms_deform_im2col_cuda.cuh @@ -0,0 +1,1327 @@ +/*! +************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************** +* Modified from DCN (https://github.com/msracver/Deformable-ConvNets) +* Copyright (c) 2018 Microsoft +************************************************************************** +*/ + +#include +#include +#include + +#include +#include + +#include + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N, const int num_threads) +{ + return (N + num_threads - 1) / num_threads; +} + + +template +__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + } + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + *grad_attn_weight = top_grad * val; + *grad_sampling_loc = width * grad_w_weight * top_grad_value; + *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + atomicAdd(grad_attn_weight, top_grad * val); + atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value); + atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value); +} + + +template +__global__ void ms_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + scalar_t *data_col_ptr = data_col + index; + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + scalar_t col = 0; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride); + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight; + } + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + } + } + *data_col_ptr = col; + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockSize; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockSize/2; s>0; s>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + } + __syncthreads(); + } + + if (tid == 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockDim.x; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]); + atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]); + atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]); + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear_gm( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + grad_sampling_loc, grad_attn_weight); + } + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +void ms_deformable_im2col_cuda(cudaStream_t stream, + const scalar_t* data_value, + const int64_t* data_spatial_shapes, + const int64_t* data_level_start_index, + const scalar_t* data_sampling_loc, + const scalar_t* data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* data_col) +{ + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + const int num_threads = CUDA_NUM_THREADS; + ms_deformable_im2col_gpu_kernel + <<>>( + num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, + batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } + +} + +template +void ms_deformable_col2im_cuda(cudaStream_t stream, + const scalar_t* grad_col, + const scalar_t* data_value, + const int64_t * data_spatial_shapes, + const int64_t * data_level_start_index, + const scalar_t * data_sampling_loc, + const scalar_t * data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels; + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + if (channels > 1024) + { + if ((channels & 1023) == 0) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_gm + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + else{ + switch(channels) + { + case 1: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 2: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 4: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 8: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 16: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 32: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 64: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 128: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 256: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 512: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 1024: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + default: + if (channels < 64) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + } + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } + +} \ No newline at end of file diff --git a/detect_tools/upn/ops/src/ms_deform_attn.h b/detect_tools/upn/ops/src/ms_deform_attn.h new file mode 100644 index 0000000000000000000000000000000000000000..7851c86a090b5aabefc5fa60de1d97f0e4130c7b --- /dev/null +++ b/detect_tools/upn/ops/src/ms_deform_attn.h @@ -0,0 +1,62 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once + +#include "cpu/ms_deform_attn_cpu.h" + +#ifdef WITH_CUDA +#include "cuda/ms_deform_attn_cuda.h" +#endif + + +at::Tensor +ms_deform_attn_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + if (value.is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_forward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +std::vector +ms_deform_attn_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + if (value.is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_backward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + diff --git a/detect_tools/upn/ops/src/vision.cpp b/detect_tools/upn/ops/src/vision.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2201f63a51dca16d0b31148ed2c9e8e47ec15bdc --- /dev/null +++ b/detect_tools/upn/ops/src/vision.cpp @@ -0,0 +1,16 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include "ms_deform_attn.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); + m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); +} diff --git a/detect_tools/upn/ops/test.py b/detect_tools/upn/ops/test.py new file mode 100644 index 0000000000000000000000000000000000000000..8dbf6d5547d131f01a8c5c28b76557bd27a9334b --- /dev/null +++ b/detect_tools/upn/ops/test.py @@ -0,0 +1,89 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import time +import torch +import torch.nn as nn +from torch.autograd import gradcheck + +from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch + + +N, M, D = 1, 2, 2 +Lq, L, P = 2, 2, 2 +shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() +level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) +S = sum([(H*W).item() for H, W in shapes]) + + +torch.manual_seed(3) + + +@torch.no_grad() +def check_forward_equal_with_pytorch_double(): + value = torch.rand(N, S, M, D).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() + output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() + fwdok = torch.allclose(output_cuda, output_pytorch) + max_abs_err = (output_cuda - output_pytorch).abs().max() + max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() + + print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') + + +@torch.no_grad() +def check_forward_equal_with_pytorch_float(): + value = torch.rand(N, S, M, D).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() + output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() + fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) + max_abs_err = (output_cuda - output_pytorch).abs().max() + max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() + + print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') + + +def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): + + value = torch.rand(N, S, M, channels).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + func = MSDeformAttnFunction.apply + + value.requires_grad = grad_value + sampling_locations.requires_grad = grad_sampling_loc + attention_weights.requires_grad = grad_attn_weight + + gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) + + print(f'* {gradok} check_gradient_numerical(D={channels})') + + +if __name__ == '__main__': + check_forward_equal_with_pytorch_double() + check_forward_equal_with_pytorch_float() + + for channels in [30, 32, 64, 71, 1025, 2048, 3096]: + check_gradient_numerical(channels, True, True, True) + + + diff --git a/detect_tools/upn/transforms/transform.py b/detect_tools/upn/transforms/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..4bb5b11a70c1fe2ea1825d257ed9e121e8e17008 --- /dev/null +++ b/detect_tools/upn/transforms/transform.py @@ -0,0 +1,142 @@ +import random +import torch +import torchvision.transforms.functional as F + + +def resize(image, target, size, max_size=None): + # size can be min_size (scalar) or (w, h) tuple + + def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = int(round(max_size * min_original_size / max_original_size)) + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + + return (oh, ow) + + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + size = get_size(image.size, size, max_size) + rescaled_image = F.resize(image, size) + + if target is None: + return rescaled_image, None + + ratios = tuple( + float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size) + ) + ratio_width, ratio_height = ratios + + target = target.copy() + if "exampler_box" in target: + boxes = target["exampler_box"] + if isinstance(boxes, torch.Tensor): + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height] + ) + target["exampler_box"] = scaled_boxes + elif isinstance(boxes, dict): + for k, v in boxes.items(): + scaled_boxes = v * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height] + ) + target["exampler_box"][k] = scaled_boxes + + if "demo_pos_exampler_box" in target: + boxes = target["demo_pos_exampler_box"] + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height] + ) + target["demo_pos_exampler_box"] = scaled_boxes + + if "demo_neg_exampler_box" in target: + boxes = target["demo_neg_exampler_box"] + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height] + ) + target["demo_neg_exampler_box"] = scaled_boxes + + if "boxes" in target: + boxes = target["boxes"] + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height] + ) + target["boxes"] = scaled_boxes + + if "area" in target: + area = target["area"] + scaled_area = area * (ratio_width * ratio_height) + target["area"] = scaled_area + + h, w = size + target["size"] = torch.tensor([h, w]) + + return rescaled_image, target + + +class RandomResize(object): + + def __init__(self, sizes, max_size=None): + assert isinstance(sizes, (list, tuple)) + self.sizes = sizes + self.max_size = max_size + + def __call__(self, img, target=None): + size = random.choice(self.sizes) + return resize(img, target, size, self.max_size) + + +class ToTensor(object): + + def __call__(self, img, target): + return F.to_tensor(img), target + + +class Normalize(object): + + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, image, target=None): + image = F.normalize(image, mean=self.mean, std=self.std) + if target is None: + return image, None + target = target.copy() + h, w = image.shape[-2:] + return image, target + + +class Compose(object): + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return image, target + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + return format_string diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..2895cea4aecc2fc450b686e2f4821c2f000b9fe5 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +torch==2.6.0 +torchvision==0.21.0 +transformers==4.50.1 +timm==1.0.9 +accelerate==1.4.0 +gradio +mmengine==0.8.2 +einops +flash-attn +scikit-image \ No newline at end of file diff --git a/run.sh b/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..5482f8d171bb4e5f74bc648c27a8ca3869346091 --- /dev/null +++ b/run.sh @@ -0,0 +1,23 @@ +echo "--- install dependencies ---" +pip install -r requirements.txt + +# 2. 检查基础包是否安装成功 +if [ $? -ne 0 ]; then + echo "install dependencies failed, exit." + exit 1 +fi + +echo "--- install dependencies successfully ---" +echo "--- compile and install local 'ops' package ---" + +pip install --no-build-isolation -e ./VLM-FO1/detect_tools/upn/ops + +if [ $? -ne 0 ]; then + echo "compile and install local 'ops' package failed, exit." + exit 1 +fi + +echo "--- compile and install local 'ops' package successfully ---" +echo "--- launch Gradio application ---" + +python demo/gradio_demo.py \ No newline at end of file diff --git a/vlm_fo1/__init__.py b/vlm_fo1/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/vlm_fo1/__init__.py @@ -0,0 +1 @@ + diff --git a/vlm_fo1/constants.py b/vlm_fo1/constants.py new file mode 100755 index 0000000000000000000000000000000000000000..ae6734f1b925109d6c03e69c322849c6aeb34702 --- /dev/null +++ b/vlm_fo1/constants.py @@ -0,0 +1,29 @@ +LOGDIR = "." + +global DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +# Model Constants +IGNORE_INDEX = -100 +IMAGE_TOKEN_INDEX = -200 #151656 #151655 #-200 +DEFAULT_IMAGE_TOKEN = "" +DEFAULT_IMAGE_PATCH_TOKEN = "" +DEFAULT_IM_START_TOKEN = "" +DEFAULT_IM_END_TOKEN = "" + +# For Qwen2_5_VL +QWEN2_5_VL_IMAGE_TOKEN = "<|image_pad|>" +QWEN2_5_VL_IMAGE_TOKEN_INDEX = 151655 + +# For regions +DEFAULT_REGION_TOKEN = ">" +DEFAULT_REGION_FEATURE_TOKEN = "" +DEFAULT_REGION_INDEX = -300 #151654 #151654 #-300 + +# For Grounding +DEFAULT_GROUNDING_START = "" +DEFAULT_GROUNDING_END = "" +DEFAULT_GROUNDING_OBJECTS_START = "" +DEFAULT_GROUNDING_OBJECTS_END = "" + +# For Think +DEFAULT_THINK_START = "" +DEFAULT_THINK_END = "" diff --git a/vlm_fo1/mm_utils.py b/vlm_fo1/mm_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..a94ae047ba3475dbcfad9c6bb65220a7e29d3d9a --- /dev/null +++ b/vlm_fo1/mm_utils.py @@ -0,0 +1,658 @@ +from PIL import Image +from PIL import ImageDraw, ImageOps +from io import BytesIO +import base64 +import re +import torch +from transformers import StoppingCriteria +from vlm_fo1.constants import IMAGE_TOKEN_INDEX, DEFAULT_REGION_INDEX +import requests +from vlm_fo1.constants import ( + IMAGE_TOKEN_INDEX, + DEFAULT_IMAGE_TOKEN, + DEFAULT_IM_START_TOKEN, + DEFAULT_IM_END_TOKEN, + IGNORE_INDEX, + DEFAULT_REGION_TOKEN, + DEFAULT_REGION_FEATURE_TOKEN +) +import torch +from transformers import TextStreamer +import random +import re +from typing import List, Tuple +import io +import base64 + + +def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): + """ + Tokenizes prompts containing or ... special tokens. + + If the prompt uses , , ..., each is replaced with a placeholder index (-200). + If the prompt uses , it is replaced with image_token_index. + + Args: + prompt (str): The prompt potentially containing image tokens. + tokenizer: The tokenizer object. + image_token_index (int): Token id to use when encountering token. + return_tensors (Optional[str]): If 'pt', return a torch tensor. + + Returns: + List[int] or torch.Tensor: The tokenized input with image token indices inserted appropriately. + """ + if "" in prompt: + # Case: prompt contains indexed image tokens like , , etc. + image_token_pattern = re.compile(r"") + prompt_chunks = re.split(r'', prompt) + image_tags = image_token_pattern.findall(prompt) + + input_ids = [] + for i, chunk in enumerate(prompt_chunks): + input_ids.extend(tokenizer(chunk).input_ids) + if i < len(image_tags): + # Insert placeholder where token was. + input_ids.append(-200) + else: + # Case: prompt contains plain tokens. + prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] + + def insert_separator(X, sep): + # Helper function to insert a separator token between chunks. + return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] + + input_ids = [] + offset = 0 + # If first chunk starts with token, make sure to keep it only once. + if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + # Insert image_token_index between chunks. + for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): + input_ids.extend(x[offset:]) + # Optionally convert output to PyTorch tensor. + if return_tensors is not None: + if return_tensors == 'pt': + return torch.tensor(input_ids, dtype=torch.long) + else: + raise ValueError(f'Unsupported tensor type: {return_tensors}') + + return input_ids + +def tokenizer_image_region_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, region_token_index=DEFAULT_REGION_INDEX, return_tensors=None): + """ + Tokenizes prompts containing both and delimiters, inserting specified token indices. + + Each chunk is split, and within that chunk, locations receive region_token_index. + + Args: + prompt (str): The prompt with and delimiters. + tokenizer: The tokenizer object. + image_token_index (int): Insert this at splits. + region_token_index (int): Insert this at splits. + return_tensors (Optional[str]): If 'pt', return torch tensor. + + Returns: + List[int] or torch.Tensor: The tokenized input with region/image tokens placed. + """ + # Split by tags first. + image_chunks = prompt.split('') + + prompt_chunks = [] + for chunk in image_chunks: + # Split each image chunk by . + obj_chunks = chunk.split('') + # Tokenize each subchunk. + token_chunks = [tokenizer(c).input_ids for c in obj_chunks] + prompt_chunks.append(token_chunks) + + input_ids = [] + offset = 0 + + # If first chunk starts with token, include only once. + if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and len(prompt_chunks[0][0]) > 0 and prompt_chunks[0][0][0] == tokenizer.bos_token_id: + offset = 1 + input_ids.append(prompt_chunks[0][0][0]) + + # Stitch together all chunks with region/image tokens at appropriate locations. + for i, chunk_group in enumerate(prompt_chunks): + if len(chunk_group) > 0: + input_ids.extend(chunk_group[0][offset:]) + for chunk in chunk_group[1:]: + input_ids.append(region_token_index) + input_ids.extend(chunk) + # Insert token except after the last image chunk. + if i < len(prompt_chunks) - 1: + input_ids.append(image_token_index) + # Optionally convert to PyTorch tensor. + if return_tensors is not None: + if return_tensors == 'pt': + return torch.tensor(input_ids, dtype=torch.long) + else: + raise ValueError(f'Unsupported tensor type: {return_tensors}') + + return input_ids + +class KeywordsStoppingCriteria(StoppingCriteria): + """ + Implements custom stopping criteria for generation based on keywords: + If the generated output contains any of the keywords, generation stops. + """ + def __init__(self, keywords, tokenizer, input_ids): + self.keywords = keywords + self.keyword_ids = [] + self.max_keyword_len = 0 + for keyword in keywords: + cur_keyword_ids = tokenizer(keyword).input_ids + # Remove BOS if present except for single token + if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: + cur_keyword_ids = cur_keyword_ids[1:] + if len(cur_keyword_ids) > self.max_keyword_len: + self.max_keyword_len = len(cur_keyword_ids) + self.keyword_ids.append(torch.tensor(cur_keyword_ids)) + self.tokenizer = tokenizer + # Track the generation start length + self.start_len = input_ids.shape[1] + + def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + """ + Checks if a keyword exists in the latest generated output ids for a single batch element. + """ + offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) + self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] + for keyword_id in self.keyword_ids: + truncated_output_ids = output_ids[0, -keyword_id.shape[0]:] + if torch.equal(truncated_output_ids, keyword_id): + return True + outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] + for keyword in self.keywords: + if keyword in outputs: + return True + return False + + def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + """ + Checks for keywords in each batch item; stops when all have satisfied the keyword condition. + """ + outputs = [] + for i in range(output_ids.shape[0]): + outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) + return all(outputs) + +def load_image(image_file): + """ + Loads an image from a local path, base64 string, URL, or PIL.Image. + + If the input image is smaller than 28x28, it will be resized to at least that size. + + Args: + image_file (str or PIL.Image.Image): Image source. + + Returns: + PIL.Image.Image: Loaded image in RGB mode, at least 28x28 in size. + """ + if isinstance(image_file, Image.Image): + image = image_file.convert("RGB") + # Case: load from URL + elif image_file.startswith("http") or image_file.startswith("https"): + response = requests.get(image_file) + image = Image.open(BytesIO(response.content)).convert("RGB") + # Case: load from base64-encoded string + elif image_file.startswith("data:image/"): + image = image_file.replace("data:image/jpeg;base64,", "") + image_data = base64.b64decode(image) + image = Image.open(BytesIO(image_data)).convert("RGB") + elif isinstance(image_file, str): + # Case: load from local file path + image = Image.open(image_file).convert("RGB") + else: + raise ValueError(f"Unsupported image type: {type(image_file)}") + + # Ensure minimum size 28x28 + if image.width < 28 or image.height < 28: + image = image.resize((max(28, image.width), max(28, image.height))) + return image + +def image_to_base64(img_pil): + """ + Encodes a PIL Image as JPEG in base64 format. + + Args: + img_pil (PIL.Image.Image): Source image. + + Returns: + str: base64-encoded JPEG image string. + """ + with io.BytesIO() as buffer: + img_pil.save(buffer, format="JPEG") + base64_image = base64.b64encode(buffer.getvalue()).decode('utf-8') + return base64_image + +def draw_bboxes_and_save( + image: Image.Image, + fo1_bboxes: dict = {}, + detection_bboxes: List[Tuple[int, int, int, int]] = [], + output_path: str = 'output.jpg', + color: str = 'red', + total_color: str = 'green', + width: int = 2 +) -> None: + """ + Draws bounding boxes (both ground-truth/proposed and detection) on a PIL image and saves result. + + Args: + image (PIL.Image.Image): Input PIL image object. + fo1_bboxes (dict): Label -> List[bbox] mapping for annotation bboxes. + detection_bboxes (List[Tuple]): List of detection bounding boxes; each bbox is (x_min, y_min, x_max, y_max). + output_path (str): Path to save the output image. + color (str): Color for fo1_bboxes. + total_color (str): Color for detection_bboxes. + width (int): Rectangle outline width. + + Returns: + None + """ + draw = ImageDraw.Draw(image) + + # Draw detection boxes with `total_color` + for bbox in detection_bboxes: + if len(bbox) != 4: + print(f"Warning: skip the invalid bbox {bbox}") + continue + shape = [(bbox[0], bbox[1]), (bbox[2], bbox[3])] + draw.rectangle(shape, outline=total_color, width=width) + + # Draw annotated bboxes with labels and `color` + for bbox_label, bbox_list in fo1_bboxes.items(): + for bbox in bbox_list: + if len(bbox) != 4: + print(f"Warning: skip the invalid bbox {bbox}") + continue + shape = [(bbox[0], bbox[1]), (bbox[2], bbox[3])] + draw.rectangle(shape, outline=color, width=width) + draw.text((bbox[0], bbox[1]), bbox_label, fill=color) + + # Save output image (catching common IO exceptions). + try: + image.save(output_path) + print(f"The image has been successfully saved to: {output_path}") + except IOError as e: + print(f"Error: failed to save the image to {output_path}. Reason: {e}") + +def adjust_bbox(bbox_list, original_h, original_w, resize_h, resize_w): + """ + Adjusts bounding boxes from original image size to resized image size, compensating for scaling. + + Args: + bbox_list (List[List[float]]): List of original boxes [x1, y1, x2, y2]. + original_h (int): Original image height. + original_w (int): Original image width. + resize_h (int): Resized image height. + resize_w (int): Resized image width. + + Returns: + List[List[float]]: Bounding boxes transformed to resized image coordinates. + """ + output_list = [] + def adjust_bbox_range(bbox, width, height): + # Ensure all coordinates are within the original image border. + x1, y1, x2, y2 = bbox + x1 = max(0, min(width, x1)) + y1 = max(0, min(height, y1)) + x2 = max(0, min(width, x2)) + y2 = max(0, min(height, y2)) + return [x1, y1, x2, y2] + + for bbox in bbox_list: + bbox = adjust_bbox_range(bbox, original_w, original_h) + bbox[0] = bbox[0] * resize_w / original_w + bbox[1] = bbox[1] * resize_h / original_h + bbox[2] = bbox[2] * resize_w / original_w + bbox[3] = bbox[3] * resize_h / original_h + output_list.append(bbox) + return output_list + +def extract_predictions_to_bboxes(prediction: str, bbox_list): + """ + Parse prediction string in the expected format and map each ground label + to its corresponding bounding boxes using bbox_list. + + Args: + prediction (str): Model output string with ...... markup. + bbox_list (List[List[float]]): Full list of predicted or reference bounding boxes. + + Returns: + dict: label -> list of bboxes + """ + label_to_indexes = {} + label_to_bboxes = {} + + match_pattern = r"(.*?)<\/ground>(.*?)<\/objects>" + matches = re.findall(match_pattern, prediction) + + for label_text, indexes in matches: + label_text = label_text.strip() + indexes_tags = re.findall(r"", indexes) + region_indexes = set([int(index.split("")[0]) for index in indexes_tags]) + if label_text not in label_to_indexes: + label_to_indexes[label_text] = region_indexes + else: + label_to_indexes[label_text] = label_to_indexes[label_text] | region_indexes + + for label, indexes in label_to_indexes.items(): + label_to_bboxes[label] = [bbox_list[index] for index in indexes] + + return label_to_bboxes + +def extract_predictions_to_indexes(prediction: str): + """ + Parse prediction string, returning label -> set-of-indexes mapping. + + Args: + prediction (str): Model prediction output. + + Returns: + dict: label -> set(int) + """ + label_to_indexes = {} + match_pattern = r"(.*?)<\/ground>(.*?)<\/objects>" + matches = re.findall(match_pattern, prediction) + + for label_text, indexes in matches: + label_text = label_text.strip() + indexes_tags = re.findall(r"", indexes) + region_indexes = set([int(index.split("")[0]) for index in indexes_tags]) + if label_text not in label_to_indexes: + label_to_indexes[label_text] = region_indexes + else: + label_to_indexes[label_text] = label_to_indexes[label_text] | region_indexes + + return label_to_indexes + +def resize_shortest_edge_images_and_bboxes( + image_list: List[Image.Image], + bbox_lists: List, + candidate_sizes: List[int] = [], + max_size: int = 2048 + ): + """ + Randomly selects a size for the shortest edge, and proportionally resizes both images and bounding boxes. + + The function maintains the image aspect ratio and ensures that the resized dimensions do not exceed the specified max_size. + Bounding boxes are transformed accordingly. + + Args: + image_list (List[Image.Image]): A list of PIL Image objects. + bbox_lists (List[List[List[float]]]): A list of lists of bounding boxes per image. + candidate_sizes (List[int]): Optional list of sizes to choose the target short edge from. + max_size (int): Maximum allowed long edge after resizing. + + Returns: + Tuple[List[Image.Image], List[List[List[float]]]]: + ([resized_image1, ...], [bbox_list1, ...]) - Possibly shape will match original (see below) + + Raises: + ValueError: on input list length mismatch or emptiness. + """ + bbox_tensor = torch.tensor(bbox_lists) + # Normalize input: wrap bbox_lists into list-of-list, if needed. + if len(bbox_tensor.shape) == 2 and bbox_tensor.shape[1] == 4: + bbox_lists = [bbox_lists] + + if not image_list or not bbox_lists: + raise ValueError("Input lists cannot be empty.") + if len(image_list) != len(bbox_lists): + raise ValueError("The lengths of the image list and the bounding box list must be the same.") + + # Randomly select short edge size (if given candidate sizes) + if len(candidate_sizes) > 0: + target_size = random.choice(candidate_sizes) + else: + target_size = None + + resized_images = [] + transformed_bbox_lists = [] + + # Process each image and its corresponding bbox list + for img, bboxes in zip(image_list, bbox_lists): + original_width, original_height = img.size + + # Determine scaling factor to bring short edge to target_size + shortest_side = min(original_width, original_height) + if target_size: + scale = target_size / shortest_side + else: + scale = 1.0 + + # Propose new height and width with this scale + new_height, new_width = int(original_height * scale), int(original_width * scale) + + # If resulting long edge exceeds max_size, rescale down so that it fits. + longest_side = max(new_height, new_width) + if longest_side > max_size: + scale = max_size / longest_side + new_height, new_width = int(new_height * scale), int(new_width * scale) + # Ensure images are at least 28x28 (model may expect it) + new_width = max(28, new_width) + new_height = max(28, new_height) + + # Resize image, using BICUBIC for quality if shape changes + if new_width == original_width and new_height == original_height: + resized_img = img + else: + resized_img = img.resize((new_width, new_height), Image.Resampling.BICUBIC) + resized_images.append(resized_img) + + # Transform bounding boxes + current_transformed_bboxes = [] + scale_ratio_x = new_width / original_width + scale_ratio_y = new_height / original_height + for bbox in bboxes: + x1, y1, x2, y2 = bbox + new_x1 = x1 * scale_ratio_x + new_y1 = y1 * scale_ratio_y + new_x2 = x2 * scale_ratio_x + new_y2 = y2 * scale_ratio_y + current_transformed_bboxes.append([new_x1, new_y1, new_x2, new_y2]) + transformed_bbox_lists.append(current_transformed_bboxes) + + # If original input was a single image (not list), unpack. + if len(bbox_tensor.shape) == 2 and bbox_tensor.shape[1] == 4: + return resized_images, transformed_bbox_lists[0] + else: + return resized_images, transformed_bbox_lists + +def make_message_context(tokenizer, message, chat_format="chatml"): + """ + Given a message dict, construct the prompt, tokenized context tokens, image URLs, and bbox_list. + + Handles both standard string 'content' and multi-part (list) content, appropriately placing image/region tokens. + + Args: + tokenizer: tokenizer object + message (dict): Contains role, content, and optionally bbox_list. + chat_format (str): Optionally select chat format (default 'chatml'). + + Returns: + tuple: (inp, context_tokens, image_urls, bbox_list) + """ + image_urls = [] + if chat_format == "chatml": + im_start, im_end = "<|im_start|>", "<|im_end|>" + im_start_tokens = [151644] + im_end_tokens = [151645] + nl_tokens = tokenizer.encode("\n") + role = message["role"] + content = message["content"] + bbox_list = message.get("bbox_list", None) + + if role == "system": + inp = f"{im_start}{role}\n{content}{im_end}\n" + context_tokens = tokenizer.encode( + role, allowed_special=set()) + nl_tokens + tokenizer.encode(content, allowed_special=set()) + context_tokens = im_start_tokens + context_tokens + im_end_tokens + + if role == "user": + if isinstance(content, str): + # Plain string message + inp = f"{im_start}{role}\n{content}{im_end}\n" + context_tokens = tokenizer.encode( + role, allowed_special=set()) + nl_tokens + tokenizer.encode(content, + allowed_special=set()) + context_tokens = im_start_tokens + context_tokens + im_end_tokens + if isinstance(content, list): + # Multi-part message (text and image_url parts, maybe region tokens) + inp = f"{im_start}{role}\n" + image_count = 1 + for message_part in content: + if message_part["type"] == "text": + inp += f"{message_part['text']}" + + if message_part["type"] == "image_url": + # Insert special vision/image tokens, possibly region tokens + inp += DEFAULT_IM_START_TOKEN + '' + DEFAULT_IM_END_TOKEN + '\n' + # If regions exist, add per-region special token. + if bbox_list and len(bbox_list) > 0: + for idx, bbox in enumerate(bbox_list): + inp += DEFAULT_REGION_TOKEN.replace('', str(idx)) + DEFAULT_REGION_FEATURE_TOKEN + inp += '\n' + + image_urls.append(message_part['image_url']['url']) + image_count += 1 + inp += f"{im_end}\n" + + # Choose tokenizer logic based on whether bbox (region) list exists + if bbox_list and len(bbox_list) > 0: + context_tokens = tokenizer_image_region_token(inp, tokenizer) + else: + context_tokens = tokenizer_image_token(inp, tokenizer, image_token_index=IMAGE_TOKEN_INDEX) + return inp, context_tokens, image_urls, bbox_list + +def prepare_inputs(model_name, model, image_processors, tokenizer, messages, device="cuda", max_tokens=512, top_p=1.0, temperature=0.0, do_sample=False): + """ + Fully prepares keyword arguments for model.generate (and compatible API) from messages and model specs. + + Handles prompt assembly, tokenization, image loading/preprocessing, region support, streaming, etc. + Supports specific tweak for Qwen2.5-VL style vision tokens. + + Args: + model_name (str): Model identifier string. + model: Model/config object. + image_processors (tuple): (primary, auxiliary) image processors. + tokenizer: Tokenizer object. + messages (list): Multi-message input list (chat history). + device (str): Target (usually 'cuda' or 'cpu'). + max_tokens, top_p, temperature, do_sample: Standard generation kwargs. + + Returns: + dict: ready-to-use argument dict for model.generate(). + """ + # For Qwen2.5-VL, patch vision special tokens globally. + if 'qwen2.5-vl' in model_name.lower() or 'qwen2_5_vl' in model_name.lower(): + global DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN + DEFAULT_IM_START_TOKEN = "<|vision_start|>" + DEFAULT_IM_END_TOKEN = "<|vision_end|>" + + primary_image_processor, auxiliary_image_processor = image_processors + + prompt = "" + input_tokens = [] + image_urls = [] + # Compose prompt and accumulate all components from provided messages + for message in messages: + inp, context_tokens, image_urls, bbox_list = make_message_context(tokenizer, message) + prompt += inp + input_tokens.extend(context_tokens) + + # Ensure a system prompt at start, if not already present. + if "system" not in prompt: + system_content = "system\nYou are a helpful assistant." + system_prompt = "<|im_start|>" + system_content + "<|im_end|>" + "\n" + prompt = system_prompt + prompt + system_tokens = [151644] + tokenizer(system_content).input_ids + [151645] + tokenizer("\n").input_ids + input_tokens = system_tokens + input_tokens + + # Ensure prompt ends with assistant's turn. + if not prompt.endswith("<|im_start|>assistant"): + last_assistant_prompt = "<|im_start|>" + "assistant" + "\n" + prompt += last_assistant_prompt + # last_assistant_tokens = [6] + self.tokenizer("assistant\n").input_ids + last_assistant_tokens = [151644] + tokenizer("assistant\n").input_ids + input_tokens.extend(last_assistant_tokens) + + primary_images_tensor = None + auxiliary_images_tensor = None + primary_image_grid_thws = None + if image_urls: + # Load images, resize them, and update bbox_list downstream + images = [load_image(i) for i in image_urls] + # print('original images[0].size:', images[0].size) + images, bbox_list = resize_shortest_edge_images_and_bboxes(images, bbox_list, max_size=2048) + # print('resized images[0].size:', images[0].size) + + # When region-indexed tokens are enabled + if getattr(model.config, 'mm_use_region_index_token', False): + origin_image_size = [image.size for image in images] + aux_images = images.copy() + auxiliary_images_tensor = [auxiliary_image_processor.preprocess(i, return_tensors='pt')['pixel_values'][0].to(device) for i in aux_images] + + if bbox_list and len(bbox_list) > 0: + # Limit number of bbox (for computational constraints, etc.) + bbox_list = bbox_list[:100] + resize_h, resize_w = auxiliary_images_tensor[0].shape[-2:] + original_w, original_h = origin_image_size[0] + # Adjust bbox to match resized images (post pre-processing) + bbox_list = adjust_bbox(bbox_list, original_h, original_w, resize_h, resize_w) + bbox_list = [torch.tensor(bbox_list)] + else: + bbox_list = None + else: + auxiliary_images_tensor = None + + # Preprocess primary images for main vision model branch + primary_images = [] + primary_image_grid_thws = [] + for im in images: + processed_data = primary_image_processor.preprocess(im, videos=None, return_tensors="pt") + image_i = processed_data['pixel_values'] + image_grid_thw_i = processed_data['image_grid_thw'] + primary_images.append(image_i) + primary_image_grid_thws.append(image_grid_thw_i) + primary_images_tensor = [image_i.to(device) for image_i in primary_images] + + # For Qwen-style, force specific end-token as stopping criterion + if "qwen" in model_name.lower(): + input_ids = torch.tensor([input_tokens]).to(device) + keywords = ["<|im_end|>"] + + stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) + streamer = TextStreamer( + tokenizer, skip_prompt=True, skip_special_tokens=True + ) + + # Default: greedy decoding if temperature=0. Else: enable sampling. + if temperature == 0.0: + do_sample = False + else: + do_sample = True + + print("question:================\n", prompt, "\n=================") + # print("input ids:========", input_ids, "========") + generation_kwargs = dict( + inputs=input_ids, + images=primary_images_tensor, + images_aux=auxiliary_images_tensor, + image_grid_thws=primary_image_grid_thws, + bbox_list=bbox_list, + do_sample=do_sample, + temperature=temperature, + max_new_tokens=max_tokens, + streamer=streamer, + top_p=top_p, + use_cache=True, + stopping_criteria=[stopping_criteria], + pad_token_id=tokenizer.pad_token_id + ) + return generation_kwargs + diff --git a/vlm_fo1/model/__init__.py b/vlm_fo1/model/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..de7b6f435ea394dacc44a447abbfa01e2dcdfc09 --- /dev/null +++ b/vlm_fo1/model/__init__.py @@ -0,0 +1 @@ +from .language_model.omchat_qwen2_5_vl import OmChatQwen25VLForCausalLM, OmChatQwen25VLConfig \ No newline at end of file diff --git a/vlm_fo1/model/builder.py b/vlm_fo1/model/builder.py new file mode 100755 index 0000000000000000000000000000000000000000..c1fa7cf8545448372f34799c91f74ac0ed091273 --- /dev/null +++ b/vlm_fo1/model/builder.py @@ -0,0 +1,143 @@ +from transformers import AutoTokenizer +import torch +from vlm_fo1.model import * +from safetensors.torch import load_file +import os + + +def load_pretrained_model(model_path, load_8bit=False, load_4bit=False, device="cuda"): + """ + Loads a pretrained model along with its vision towers (and associated image processors). + This function supports loading in 8bit/4bit precision and explicit device placement. + + Args: + model_path (str): Path to the pretrained model directory. + load_8bit (bool): Whether to load the model in 8bit mode. + load_4bit (bool): Whether to load the model in 4bit mode. + device (str): Device to load model onto, e.g., "cuda" or "cpu". + + Returns: + tuple: (tokenizer, model, image_processor) + """ + kwargs = {"device_map": device} + + # Set model loading parameters for quantization or floating point + if load_8bit: + kwargs['load_in_8bit'] = True + elif load_4bit: + kwargs['load_in_4bit'] = True + else: + kwargs['torch_dtype'] = torch.bfloat16 + + # print(model_path) + + # Only proceed for vlm-fo1 models + if 'vlm-fo1' in model_path.lower(): + # Load tokenizer (slow tokenizer enforced) + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + # If this is the Qwen2.5-VL variant, load with additional kwargs + if 'qwen2.5-vl' in model_path.lower() or 'qwen2_5_vl' in model_path.lower(): + model, loading_info = OmChatQwen25VLForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + output_loading_info=True, + attn_implementation="flash_attention_2", + **kwargs, + cache_dir='./resources', + ) + # print(f'OmChatQwen25VLForCausalLM loading_info: {loading_info}') + # (For other variants of vlm-fo1, model loading detail may need additional condition.) + + if 'vlm-fo1' in model_path.lower(): + # --- Vision Tower Loading --- + # Load the main vision tower weights from model_path if it is not yet loaded + primary_vision_tower = model.get_vision_tower() + if primary_vision_tower and not primary_vision_tower.is_loaded: + primary_vision_tower.load_model(model_path=model_path, is_train=False) + primary_vision_tower.to(device=device, dtype=torch.bfloat16) # Move to correct device/dtype + + # Grab primary image processor from vision tower, if present + if primary_vision_tower: + primary_image_processor = primary_vision_tower.image_processor + + # --- Auxiliary Vision Tower Handling (Qwen2.5-VL case only) --- + if 'qwen2.5-vl' in model_path.lower() or 'qwen2_5_vl' in model_path.lower(): + try: + aux_image_size = model.config.aux_image_size + except Exception: + # If aux_image_size is missing from config fallback to 768 + aux_image_size = 768 + + aux_image_aspect_ratio = model.config.aux_image_aspect_ratio + aux_vision_tower = model.get_vision_tower_aux() + # Only load if not already loaded + if aux_vision_tower and not aux_vision_tower.is_loaded: + aux_vision_tower.load_model(image_size=aux_image_size, is_train=False, aspect_ratio=aux_image_aspect_ratio) + aux_vision_tower.to(device=device, dtype=torch.bfloat16) + + # Get auxiliary image processor if there is an aux vision tower + if aux_vision_tower: + aux_image_processor = aux_vision_tower.image_processor + else: + image_processor = None # Set to None if there is no auxiliary vision tower + + # image_processor returned as a tuple of (primary, aux) + image_processor = (primary_image_processor, aux_image_processor) + + # --- Ensure vision_tower and vision_tower_aux are loaded with weights from model_path --- + # if 'vlm-fo1' in model_path.lower(): + # print(f"Loading weights from {model_path} to ensure vision_tower uses the correct weights...") # Inform user we are loading vision weights + + # # --- Gather all safetensors files in the model path (for sharded checkpoints) --- + # state_dict = {} + # safetensor_files = [f for f in os.listdir(model_path) if f.endswith('.safetensors')] + + # if safetensor_files: + # for safetensor_file in safetensor_files: + # file_path = os.path.join(model_path, safetensor_file) + # shard_state_dict = load_file(file_path, device="cpu") + # state_dict.update(shard_state_dict) + # else: + # # Fallback to legacy .bin checkpoint if no safetensors found + # state_dict = torch.load(f"{model_path}/pytorch_model.bin", map_location="cpu") + + # # --- Filter out only vision_tower and vision_tower_aux related weights --- + # vision_tower_keys = [k for k in state_dict.keys() if "vision_tower." in k] + # vision_tower_state_dict = {k: state_dict[k] for k in vision_tower_keys if k in state_dict} + + # if vision_tower_keys: + # # print(f"Found {len(vision_tower_keys)} vision_tower weights") + # # Load weights into main vision tower + # if primary_vision_tower and primary_vision_tower.is_loaded: + # # Strips the prefix "model.vision_tower." before loading (for compatibility with submodules) + # missing_keys, unexpected_keys = primary_vision_tower.load_state_dict( + # {k.replace("model.vision_tower.", ""): v for k, v in vision_tower_state_dict.items() + # if k.startswith("model.vision_tower.")}, + # strict=True + # ) + # print(f"vision_tower weights loaded, missing keys: {missing_keys}, unexpected keys: {unexpected_keys}") + + # # If there is an aux vision tower (Qwen2.5-VL) load its weights as well + # if 'qwen2.5-vl' in model_path.lower() or 'qwen2_5_vl' in model_path.lower(): + # if aux_vision_tower and aux_vision_tower.is_loaded: + # vision_tower_aux_keys = [k for k in state_dict.keys() if "vision_tower_aux." in k] + # if vision_tower_aux_keys: + # # print(f"Found {len(vision_tower_aux_keys)} vision_tower_aux weights") + # vision_tower_aux_state_dict = {k: state_dict[k] for k in vision_tower_aux_keys if k in state_dict} + # # Strip "model.vision_tower_aux." prefix before loading for compatibility + # missing_keys, unexpected_keys = aux_vision_tower.load_state_dict( + # {k.replace("model.vision_tower_aux.", ""): v for k, v in vision_tower_aux_state_dict.items() + # if k.startswith("model.vision_tower_aux.")}, + # strict=True + # ) + # print(f"vision_tower_aux weights loaded, missing keys: {missing_keys}, unexpected keys: {unexpected_keys}") + + # else: + # # If no vision tower weights found, raise an error + # print("No vision_tower weights found") + # raise Exception("No vision_tower weights found") + + # Set model to eval mode and move to correct device before returning + model.eval() + model.to(device=device, dtype=torch.bfloat16) + return tokenizer, model, image_processor diff --git a/vlm_fo1/model/language_model/omchat_qwen2_5_vl.py b/vlm_fo1/model/language_model/omchat_qwen2_5_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..dbd367afaa84600863376450efb570c3d942f36d --- /dev/null +++ b/vlm_fo1/model/language_model/omchat_qwen2_5_vl.py @@ -0,0 +1,576 @@ +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from transformers import Qwen2_5_VLConfig, AutoConfig, AutoModelForCausalLM +from vlm_fo1.model.multimodal_encoder.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel, Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLCausalLMOutputWithPast +from vlm_fo1.model.multimodal_encoder.qwen2_5_vl_encoder import Qwen2_5_VlVisionTower +from vlm_fo1.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_REGION_INDEX, QWEN2_5_VL_IMAGE_TOKEN, QWEN2_5_VL_IMAGE_TOKEN_INDEX + +from ..omchat_arch import OmChatMetaModel, OmChatMetaForCausalLM + +# Custom config which extends Qwen2_5_VLConfig for OmChat multimodal model +class OmChatQwen25VLConfig(Qwen2_5_VLConfig): + model_type = "omchat_qwen2_5_vl" + rotary_type = "normal_rotary" + multi_scale_im = None + vision_tower_aux = None + +# Core model definition: inherits from OmChat and Qwen multimodal base +class OmChatQwen25VLModel(OmChatMetaModel, Qwen2_5_VLModel): + config_class = OmChatQwen25VLConfig + + def __init__(self, config: Qwen2_5_VLConfig): + super(OmChatQwen25VLModel, self).__init__(config) + +# Main class for multimodal CausalLM +class OmChatQwen25VLForCausalLM(Qwen2_5_VLForConditionalGeneration, OmChatMetaForCausalLM): + config_class = OmChatQwen25VLConfig + + def __init__(self, config, delay_load=True): + # Ensure config has delay_load property + if not hasattr(config, 'delay_load'): + config.delay_load = delay_load + super(Qwen2_5_VLForConditionalGeneration, self).__init__(config) + self.model = OmChatQwen25VLModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.rope_deltas = None # cache rope_deltas here + + self.post_init() + + # Encode input images into feature representations + def encode_images(self, images, images_grid_thw=None): + # If vision_tower is Qwen2.5-specific, use its custom forward signature + if isinstance(self.get_model().get_vision_tower(), Qwen2_5_VlVisionTower): + image_features = self.get_model().get_vision_tower()(images, images_grid_thw) + image_features, image_grid_thws, multi_level_features = image_features + # If multiple images, handle concatenation + if type(image_features) is list: + # List has items of shape (1, seq_len, dim) + token_length_list = [i.shape[1] for i in image_features] + image_features = torch.cat(image_features, dim=1) # Concatenate to (1, total_seq_len, dim) + else: + image_features = self.get_model().get_vision_tower()(images) + image_grid_thws = None + multi_level_features = None + + image_features = self.get_model().mm_projector(image_features) + + # Split concatenated image features back by original lengths (for multi-image case) + if isinstance(self.get_model().get_vision_tower(), Qwen2_5_VlVisionTower): + start = 0 + new_image_features = [] + # Split according to token_length_list + for length in token_length_list: + end = start + length + new_image_features.append(image_features[:, start:end, :].squeeze(0)) + start = end + image_features = new_image_features + + return image_features, image_grid_thws, multi_level_features + + # Encode region regions (bounding boxes) into features, optionally using auxiliary vision tower + def encode_regions(self, images, bbox_list, vt_multi_level_features=None, vt_images_size=None): + aux_image_features_list = self.get_model().get_vision_tower_aux()(images) + region_features = [] + if getattr(self.config, "mm_use_vision_tower_region_feature", False): + image_features_list = vt_multi_level_features + for batch_idx, (image_features, aux_image_features) in enumerate(zip(image_features_list, aux_image_features_list)): + + if getattr(self.config, "mm_use_simpleFPN_for_vt", False): + multilevel_visual_feats = image_features[-1] + else: + multilevel_visual_feats = image_features + multilevel_aux_visual_feats = aux_image_features["image_features"] + boxes = bbox_list[batch_idx] + + # If no boxes provided, use dummy box (covers tiny region) + if boxes is None or len(boxes) == 0: + boxes = torch.tensor([[0, 10, 0, 10]], device=multilevel_aux_visual_feats[0].device, dtype=torch.float32) + + boxes = boxes.to(torch.float32).to(multilevel_aux_visual_feats[0].device) + current_image_height, current_image_width = images[batch_idx].shape[-2:] + original_height, original_width = vt_images_size[batch_idx] + # Scale bounding boxes from original image size to processed size + scale_height = original_height / current_image_height + scale_width = original_width / current_image_width + vt_boxes = boxes * torch.tensor([scale_width, scale_height, scale_width, scale_height], device=boxes.device) + + extracted_region_feat = self.get_model().object_vp_extractor( + aux_multi_level_features=multilevel_aux_visual_feats, + vt_multi_level_features=multilevel_visual_feats, + aux_boxes=[boxes], + vt_boxes=[vt_boxes] + ).squeeze(0).to(multilevel_aux_visual_feats[0].dtype) + region_feat = self.get_model().mm_projector_aux(extracted_region_feat) # [num_bbox, 2048] + region_features.append(region_feat) + else: + # Extract region features only from auxiliary vision tower + for batch_idx, image_features in enumerate(aux_image_features_list): + multilevel_visual_feats = image_features["image_features"] + last_feat = image_features["last_feat"] + boxes = bbox_list[batch_idx] + + if boxes is None or len(boxes) == 0: + boxes = torch.tensor([[0, 10, 0, 10]], device=multilevel_visual_feats[0].device, dtype=torch.float32) + + multi_level_aux_features = multilevel_visual_feats + boxes = boxes.to(torch.float32).to(multi_level_aux_features[0].device) + extracted_region_feat = self.get_model().object_vp_extractor( + multi_level_aux_features, + [boxes], + ).squeeze(0).to(multi_level_aux_features[0].dtype) + region_feat = self.get_model().mm_projector_aux(extracted_region_feat) # [num_bbox, 2880] + region_features.append(region_feat) + + return region_features + + def get_model(self): + # Getter for model. Used to access backbone/model internals. + return self.model + + # Convert sequence of input_ids/labels/images/boxes to multimodal embedding and associated masks/ids for transformer input. + def prepare_inputs_labels_for_qwen2_5_vl_multimodal( + self, input_ids, position_ids, attention_mask, past_key_values, labels, images, images_aux=None, bbox_list=None, image_grid_thws=None + ): + # ========================== Above this line, input parsing and batching ============================= + vision_tower = self.get_vision_tower() + video_tower = self.get_video_tower() + vision_tower_aux = self.get_vision_tower_aux() + # Fast-path for non-multimodal case or first step in generation (i.e. only one token in input) + if (vision_tower is None and video_tower is None) or images is None or input_ids.shape[1] == 1: + if past_key_values is not None and (vision_tower is not None or video_tower is not None) and images is not None and input_ids.shape[1] == 1: + + target_shape = past_key_values[-1][-1].shape[-2] + 1 + attention_mask = torch.cat((attention_mask, torch.ones( + (attention_mask.shape[0], target_shape - attention_mask.shape[1]), + dtype=attention_mask.dtype, + device=attention_mask.device + )), dim=1) + + position_ids=None + cache_position = torch.tensor([target_shape - 1],device=attention_mask.device) + return input_ids, position_ids, attention_mask, past_key_values, None, labels, None, cache_position + + # Indices for images (3D or 2D tensors) and videos (4D tensors) + image_idx = [idx for idx, img in enumerate(images) if img.ndim == 3 or img.ndim == 2] + is_all_image = len(image_idx) == len(images) + video_idx = [idx for idx, vid in enumerate(images) if vid.ndim == 4] + + # Stack image and video tensors accordingly for mini-batch processing + if isinstance(vision_tower, Qwen2_5_VlVisionTower): + images_minibatch = [images[idx] for idx in image_idx] if len(image_idx) > 0 else [] # list of [c,h,w], can have variable shapes + else: + images_minibatch = torch.stack([images[idx] for idx in image_idx]) if len(image_idx) > 0 else [] # tensor [mini_b, c, h, w] + videos_minibatch = torch.stack([images[idx] for idx in video_idx]) if len(video_idx) > 0 else [] # tensor [mini_b, c, t, h, w] + + # Auxiliary batch for region encoding, if relevant + if vision_tower_aux is not None and images_aux is not None: + images_minibatch_aux = [images_aux[idx].unsqueeze(0) for idx in image_idx] if len(image_idx) > 0 else [] # list of [1, c, h, w] + + # tmp_image_features will be indexed to scatter extracted image/video features into original batch positions + tmp_image_features = [None] * (len(image_idx) + len(video_idx)) + if getattr(images_minibatch, 'ndim', 0) == 4 or (type(images_minibatch) is list and len(images_minibatch) > 0): # batch consists of images, [mini_b, c, h, w] + if vision_tower is not None: + image_features_minibatch, image_grid_thws_minibatch, vt_multi_level_features_minibatch = self.encode_images(images_minibatch, image_grid_thws) # [mini_b, l, c] + else: + image_features_minibatch = torch.randn(1).to(self.device) # dummy feature for video-only training under tuning + + # Map extracted image features back to their places in the original batch + for i, pos in enumerate(image_idx): + tmp_image_features[pos] = image_features_minibatch[i] + + # Handle auxiliary region features if enabled and boxes provided + if vision_tower_aux is not None and bbox_list is not None and len(bbox_list) > 0: + if isinstance(self.get_model().get_vision_tower(), Qwen2_5_VlVisionTower): + patch_size = self.get_model().get_vision_tower().config.patch_size + vt_images_size_minibatch = [im_grid_thw[0][-2:]*patch_size for im_grid_thw in image_grid_thws] + region_features = self.encode_regions(images_minibatch_aux, bbox_list, vt_multi_level_features_minibatch, vt_images_size_minibatch) # [mini_b, l, c] + else: + region_features = None + + # Same as above, but for video features if any + if getattr(videos_minibatch, 'ndim', 0) == 5: # batch consists of videos, [mini_b, c, t, h, w] + video_features_minibatch = self.encode_videos(videos_minibatch) # fake list [mini_b, t, l, c] + for i, pos in enumerate(video_idx): + tmp_image_features[pos] = video_features_minibatch[i] + + # Flatten image feature slot list to proper order for current batch + new_tmp = [] + for image in tmp_image_features: + # If multi-image per item, flatten out + if isinstance(image, list): + t = len(image) + for i in range(t): + new_tmp.append(image[i]) + else: + new_tmp.append(image) + image_features = new_tmp + + # =========================== Now, build multimodal input & target sequences ========================= + + if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): + raise NotImplementedError + + _labels = labels + _position_ids = position_ids + _attention_mask = attention_mask + + # Default construction of masks etc. + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + else: + attention_mask = attention_mask.bool() + if position_ids is None: + position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + if labels is None: + labels = torch.full_like(input_ids, IGNORE_INDEX) + + # For each batch item, strip padded tokens based on attention_mask + input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)] + labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] + + # If neither region auxiliary nor bboxes present: process classic image-text input + if vision_tower_aux is None and (bbox_list is None or all(x is None for x in bbox_list)): + new_input_embeds = [] + new_labels = [] + new_input_ids = [] + cur_image_idx = 0 + image_nums_in_batch = [] + + for batch_idx, cur_input_ids in enumerate(input_ids): + num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + image_nums_in_batch.append(num_images) + # If there are no image markers, just get text features + if num_images == 0: + cur_image_features = image_features[cur_image_idx] + cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) + cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0) + new_input_embeds.append(cur_input_embeds) + new_labels.append(labels[batch_idx]) + new_input_ids.append(cur_input_ids) + cur_image_idx += 1 + continue + + # Split on image token indices: replace them with image features after conversion + image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] + cur_input_ids_noim = [] + cur_labels = labels[batch_idx] + cur_labels_noim = [] + for i in range(len(image_token_indices) - 1): + cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]]) + cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]]) + split_sizes = [x.shape[0] for x in cur_labels_noim] + cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim)) + cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) + + cur_new_input_embeds = [] + cur_new_labels = [] + cur_new_input_ids = [] + for i in range(num_images + 1): + # Interleave text and image features + cur_new_input_embeds.append(cur_input_embeds_no_im[i]) + cur_new_labels.append(cur_labels_noim[i]) + cur_new_input_ids.append(cur_input_ids_noim[i]) + if i < num_images: + cur_image_features = image_features[cur_image_idx].to(self.device) + cur_image_idx += 1 + cur_new_input_embeds.append(cur_image_features) + cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) + cur_new_input_ids.append(torch.full((cur_image_features.shape[0],), self.config.image_token_id, device=cur_labels.device, dtype=cur_labels.dtype)) + cur_new_input_embeds = torch.cat(cur_new_input_embeds) + cur_new_labels = torch.cat(cur_new_labels) + cur_new_input_ids = torch.cat(cur_new_input_ids) + + new_input_embeds.append(cur_new_input_embeds) + new_labels.append(cur_new_labels) + new_input_ids.append(cur_new_input_ids) + # If region markers or region features enabled in config + else: + new_input_embeds = [] + new_labels = [] + new_input_ids = [] + cur_image_idx = 0 + image_nums_in_batch = [] + + for batch_idx, cur_input_ids in enumerate(input_ids): + cur_region_idx = 0 + # Detect image and region special token counts + num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + num_regions = (cur_input_ids == DEFAULT_REGION_INDEX).sum() if DEFAULT_REGION_INDEX in cur_input_ids else 0 + image_nums_in_batch.append(num_images) + + # If no markers, just do text embedding for this item + if num_images == 0 and num_regions == 0: + cur_image_features = image_features[cur_image_idx] + cur_region_features = region_features[cur_region_idx] + cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) + cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_region_features[0:0]], dim=0) + new_input_embeds.append(cur_input_embeds) + new_labels.append(labels[batch_idx]) + new_input_ids.append(cur_input_ids) + cur_image_idx += 1 + continue + + # Get all special marker indices (image/region) + image_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + region_indices = torch.where(cur_input_ids == DEFAULT_REGION_INDEX)[0].tolist() if num_regions > 0 else [] + all_special_indices = sorted([-1] + image_indices + region_indices + [cur_input_ids.shape[0]]) + + # Split out plain text chunks between special markers + cur_input_ids_segments = [] + cur_labels = labels[batch_idx] + cur_labels_segments = [] + + for i in range(len(all_special_indices) - 1): + cur_input_ids_segments.append(cur_input_ids[all_special_indices[i]+1:all_special_indices[i+1]]) + cur_labels_segments.append(cur_labels[all_special_indices[i]+1:all_special_indices[i+1]]) + + # Project text ids to word embeddings + split_sizes = [x.shape[0] for x in cur_labels_segments] + cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_segments)) + if num_regions == 0 and vision_tower_aux is not None and region_features is not None: + cur_region_features = region_features[cur_region_idx] + temp_input_embeds = torch.cat([cur_input_embeds, cur_region_features[0:0]], dim=0) + cur_input_embeds = temp_input_embeds + + cur_input_embeds_segments = torch.split(cur_input_embeds, split_sizes, dim=0) + + # Reassemble text and image/region segments in order + cur_new_input_embeds = [] + cur_new_labels = [] + cur_new_input_ids = [] + + for i in range(len(all_special_indices) - 1): + # Insert current text segment + cur_new_input_embeds.append(cur_input_embeds_segments[i]) + cur_new_labels.append(cur_labels_segments[i]) + cur_new_input_ids.append(cur_input_ids_segments[i]) + # If next is image, insert feature representation + if all_special_indices[i+1] in image_indices: + cur_image_features = image_features[cur_image_idx].to(self.device) + cur_image_idx += 1 + cur_new_input_embeds.append(cur_image_features) + cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) + cur_new_input_ids.append(torch.full((cur_image_features.shape[0],), self.config.image_token_id, device=cur_labels.device, dtype=cur_labels.dtype)) + + # If next is region token, insert extracted region features + elif all_special_indices[i+1] in region_indices: + cur_region_features = region_features[batch_idx][cur_region_idx].to(self.device).unsqueeze(0) + cur_region_idx += 1 + cur_new_input_embeds.append(cur_region_features) + + cur_new_labels.append(torch.full((cur_region_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) + cur_new_input_ids.append(torch.full((cur_region_features.shape[0],), DEFAULT_REGION_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) + # Combine for this batch item + cur_new_input_embeds = torch.cat(cur_new_input_embeds) + cur_new_labels = torch.cat(cur_new_labels) + cur_new_input_ids = torch.cat(cur_new_input_ids) + new_input_embeds.append(cur_new_input_embeds) + new_labels.append(cur_new_labels) + new_input_ids.append(cur_new_input_ids) + # Truncate sequences to maximum model length, if image+region tokens caused overflow + tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None) + if tokenizer_model_max_length is not None: + new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] + new_labels = [x[:tokenizer_model_max_length] for x in new_labels] + + # Pad sequences in the batch to same length; compute batch masks + max_len = max(x.shape[0] for x in new_input_embeds) + batch_size = len(new_input_embeds) + + new_input_embeds_padded = [] + new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device) + new_input_ids_padded = torch.full((batch_size, max_len), self.config.bos_token_id, dtype=new_input_ids[0].dtype, device=new_input_ids[0].device) + attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) + position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) + + # Left or right padding as per config; fill padded tensors + for i, (cur_new_embed, cur_new_labels, cur_new_input_ids) in enumerate(zip(new_input_embeds, new_labels, new_input_ids)): + cur_len = cur_new_embed.shape[0] + if getattr(self.config, 'tokenizer_padding_side', 'right') == "left": + # Left pad: add zeros before text tokens/features + new_input_embeds_padded.append(torch.cat(( + torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), + cur_new_embed + ), dim=0)) + if cur_len > 0: + new_labels_padded[i, -cur_len:] = cur_new_labels + attention_mask[i, -cur_len:] = True + position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + else: + # Right pad: add zeros after text tokens/features + new_input_embeds_padded.append(torch.cat(( + cur_new_embed, + torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device) + ), dim=0)) + if cur_len > 0: + new_labels_padded[i, :cur_len] = cur_new_labels + new_input_ids_padded[i, :cur_len] = cur_new_input_ids + attention_mask[i, :cur_len] = True + position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + + new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) + new_input_ids = new_input_ids_padded + + # Only set new_labels if original labels were not None + if _labels is None: + new_labels = None + else: + new_labels = new_labels_padded + + # Similarly handle provided attention_mask/position_ids overrides + if _attention_mask is None: + attention_mask = None + else: + attention_mask = attention_mask.to(dtype=_attention_mask.dtype) + + if _position_ids is None: + position_ids = None + + # For Qwen2.5 vision towers, use and concatenate image_grid_thws for positional computations + if isinstance(self.get_model().get_vision_tower(), Qwen2_5_VlVisionTower): + image_grid_thws = [] + cur_image_idx = 0 + for num_images in image_nums_in_batch: + if num_images == 0: + cur_image_idx += 1 + continue + image_grid_thws += image_grid_thws_minibatch[cur_image_idx:cur_image_idx+num_images] + cur_image_idx += num_images + + if len(image_grid_thws) > 0: + image_grid_thws = torch.cat(image_grid_thws, dim=0) + else: + image_grid_thws = None + + rope_index_kwargs = { + "input_ids": new_input_ids, + "image_grid_thw": image_grid_thws, + "video_grid_thw": None, + "attention_mask": attention_mask, + } + + # Compute new position_ids and rope_deltas for transformer (for rotary embeddings) + position_ids, rope_deltas = self.get_rope_index(**rope_index_kwargs) + cache_position = torch.arange(new_input_embeds.shape[1], device=new_input_embeds.device) + else: + rope_deltas = None + cache_position = None + # Final output is a tuple mimicking HuggingFace prepare_inputs_for_generation return + return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels, rope_deltas, cache_position + + # Patch forward() of HF CausalLM to allow multimodal embedding with images/regions + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + images: Optional[torch.FloatTensor] = None, + images_aux: Optional[torch.FloatTensor] = None, + bbox_list: Optional[torch.FloatTensor] = None, + image_grid_thws: Optional[torch.FloatTensor] = None, + ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: + + if inputs_embeds is None: + ( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels, + rope_deltas, + cache_position + ) = self.prepare_inputs_labels_for_qwen2_5_vl_multimodal( + input_ids, + position_ids, + attention_mask, + past_key_values, + labels, + images, + images_aux, + bbox_list, + image_grid_thws + ) + + if rope_deltas is not None: + self.rope_deltas = rope_deltas + + # Call base CausalLM forward, with possibly replaced multimodal embeddings + out = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + rope_deltas=rope_deltas, + cache_position=cache_position, + second_per_grid_ts=second_per_grid_ts, + return_dict=return_dict + ) + return out + + # Prepare model input dict for autoregressive generation (for use with generation methods like generate()) + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + second_per_grid_ts=None, + images: Optional[torch.FloatTensor] = None, + images_aux: Optional[torch.FloatTensor] = None, + bbox_list: Optional[torch.FloatTensor] = None, + image_grid_thws: Optional[torch.FloatTensor] = None, + **kwargs, + ): + # Wrap parent logic so extra multimodal kwargs are preserved + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + images=images, + images_aux=images_aux, + bbox_list=bbox_list, + image_grid_thws=image_grid_thws, + ) + return model_inputs + +# Register our config and model with HuggingFace transformers registry +AutoConfig.register("omchat_qwen2_5_vl", OmChatQwen25VLConfig) +AutoModelForCausalLM.register(OmChatQwen25VLConfig, OmChatQwen25VLForCausalLM) diff --git a/vlm_fo1/model/multimodal_encoder/__init__.py b/vlm_fo1/model/multimodal_encoder/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vlm_fo1/model/multimodal_encoder/base_encoder.py b/vlm_fo1/model/multimodal_encoder/base_encoder.py new file mode 100755 index 0000000000000000000000000000000000000000..0918a4008736d44bedf1a1791f319ae7b34d6c69 --- /dev/null +++ b/vlm_fo1/model/multimodal_encoder/base_encoder.py @@ -0,0 +1,33 @@ +import torch +import torch.nn as nn + + +class AbsVisionTower(nn.Module): + @torch.no_grad() + def forward(self, images): + raise NotImplementedError + + @property + def dummy_feature(self): + raise NotImplementedError + + @property + def dtype(self): + raise NotImplementedError + + @property + def device(self): + raise NotImplementedError + + @property + def config(self): + raise NotImplementedError + + + @property + def hidden_size(self): + raise NotImplementedError + + @property + def num_patches(self): + raise NotImplementedError diff --git a/vlm_fo1/model/multimodal_encoder/builder.py b/vlm_fo1/model/multimodal_encoder/builder.py new file mode 100755 index 0000000000000000000000000000000000000000..d99cb4b34493a228b1c4e6011df8a3834cea6a4e --- /dev/null +++ b/vlm_fo1/model/multimodal_encoder/builder.py @@ -0,0 +1,38 @@ +# Builders for different vision tower backbones (MM encoder visual modules) +from .qwen2_5_vl_encoder import Qwen2_5_VlVisionTower # Main Qwen2.5 vision tower +from .davit_aux_encoder import DavitVisionTower as DavitVisionTowerAux # Auxiliary DaViT vision tower + +def build_vision_tower(vision_tower_cfg, **kwargs): + """ + Use model config to construct the main vision tower. + + vision_tower_cfg: should have attribute mm_vision_tower + Returns: instance of configured vision backbone + """ + vision_tower_name = getattr(vision_tower_cfg, 'mm_vision_tower', None) + print(vision_tower_cfg) # Debug print of the config being used + + # Check for the Qwen2.5-VL vision model in tower name + if "qwen2.5-vl" in vision_tower_name.lower(): + return Qwen2_5_VlVisionTower(vision_tower_name, args=vision_tower_cfg, **kwargs) + + # Raise a clear error for unknown towers + raise ValueError(f'Unknown vision tower: {vision_tower_name}') + +def build_vision_tower_aux(vision_tower_cfg, **kwargs): + """ + Use model config to construct the auxiliary (helper) vision tower. + + vision_tower_cfg: should have attribute mm_vision_tower_aux + Returns: instance of configured auxiliary vision backbone + """ + vision_tower_aux = getattr(vision_tower_cfg, 'mm_vision_tower_aux', None) + # Optionally print config for debugging + # print(vision_tower_cfg) + + # Check for the DaViT auxiliary vision model in tower name + if 'davit' in vision_tower_aux.lower(): + return DavitVisionTowerAux(vision_tower_aux, args=vision_tower_cfg, **kwargs) + + # Raise a clear error if tower type is unknown + raise ValueError(f'Unknown aux vision tower: {vision_tower_aux}') diff --git a/vlm_fo1/model/multimodal_encoder/davit/configs.py b/vlm_fo1/model/multimodal_encoder/davit/configs.py new file mode 100644 index 0000000000000000000000000000000000000000..7d0dc9b07f7c29306111973544b0fed32dcd9613 --- /dev/null +++ b/vlm_fo1/model/multimodal_encoder/davit/configs.py @@ -0,0 +1,152 @@ + +model_configs = { + "davit-base": { + "depths": [ + 1, + 1, + 9, + 1 + ], + "dim_embed": [ + 128, + 256, + 512, + 1024 + ], + "drop_path_rate": 0.1, + "enable_checkpoint": True, + "image_feature_source": [ + "spatial_avg_pool", + "temporal_avg_pool" + ], + "image_pos_embed": { + "max_pos_embeddings": 50, + "type": "learned_abs_2d" + }, + "num_groups": [ + 4, + 8, + 16, + 32 + ], + "num_heads": [ + 4, + 8, + 16, + 32 + ], + "patch_padding": [ + 3, + 1, + 1, + 1 + ], + "patch_prenorm": [ + False, + True, + True, + True + ], + "patch_size": [ + 7, + 3, + 3, + 3 + ], + "patch_stride": [ + 4, + 2, + 2, + 2 + ], + "projection_dim": 768, + "transformers_version": "4.41.2", + "visual_temporal_embedding": { + "max_temporal_embeddings": 100, + "type": "COSINE" + }, + "window_size": 12 + }, + "davit-large": { + "depths": [ + 1, + 1, + 9, + 1 + ], + "dim_embed": [ + 256, + 512, + 1024, + 2048 + ], + "drop_path_rate": 0.1, + "enable_checkpoint": True, + "image_feature_source": [ + "spatial_avg_pool", + "temporal_avg_pool" + ], + "image_pos_embed": { + "max_pos_embeddings": 50, + "type": "learned_abs_2d" + }, + "num_groups": [ + 8, + 16, + 32, + 64 + ], + "num_heads": [ + 8, + 16, + 32, + 64 + ], + "patch_padding": [ + 3, + 1, + 1, + 1 + ], + "patch_prenorm": [ + False, + True, + True, + True + ], + "patch_size": [ + 7, + 3, + 3, + 3 + ], + "patch_stride": [ + 4, + 2, + 2, + 2 + ], + "projection_dim": 1024, + "transformers_version": "4.41.2", + "visual_temporal_embedding": { + "max_temporal_embeddings": 100, + "type": "COSINE" + }, + "window_size": 12 + } +} + +img_cfg = { + "do_resize": True, + "size": { + "height": 768, + "width":768 + }, + "resample": 3, + "do_center_crop": False, + "do_rescale": True, + "do_normalize": True, + "image_mean": [0.485, 0.456, 0.406], + "image_std": [0.229, 0.224, 0.225], + "do_convert_rgb": True +} diff --git a/vlm_fo1/model/multimodal_encoder/davit/configuration_davit.py b/vlm_fo1/model/multimodal_encoder/davit/configuration_davit.py new file mode 100644 index 0000000000000000000000000000000000000000..72d435367ef9347eabdecb00f22ea5f76a90c4a0 --- /dev/null +++ b/vlm_fo1/model/multimodal_encoder/davit/configuration_davit.py @@ -0,0 +1,119 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from transformers import AutoConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +class DavitConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel + according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Florence2VisionModel architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + drop_path_rate (`float`, *optional*, defaults to 0.1): + The dropout rate of the drop path layer. + patch_size (`List[int]`, *optional*, defaults to [7, 3, 3, 3]): + The patch size of the image. + patch_stride (`List[int]`, *optional*, defaults to [4, 2, 2, 2]): + The patch stride of the image. + patch_padding (`List[int]`, *optional*, defaults to [3, 1, 1, 1]): + The patch padding of the image. + patch_prenorm (`List[bool]`, *optional*, defaults to [false, true, true, true]): + Whether to apply layer normalization before the patch embedding layer. + enable_checkpoint (`bool`, *optional*, defaults to False): + Whether to enable checkpointing. + dim_embed (`List[int]`, *optional*, defaults to [256, 512, 1024, 2048]): + The dimension of the embedding layer. + num_heads (`List[int]`, *optional*, defaults to [8, 16, 32, 64]): + The number of attention heads. + num_groups (`List[int]`, *optional*, defaults to [8, 16, 32, 64]): + The number of groups. + depths (`List[int]`, *optional*, defaults to [1, 1, 9, 1]): + The depth of the model. + window_size (`int`, *optional*, defaults to 12): + The window size of the model. + projection_dim (`int`, *optional*, defaults to 1024): + The dimension of the projection layer. + visual_temporal_embedding (`dict`, *optional*): + The configuration of the visual temporal embedding. + image_pos_embed (`dict`, *optional*): + The configuration of the image position embedding. + image_feature_source (`List[str]`, *optional*, defaults to ["spatial_avg_pool", "temporal_avg_pool"]): + The source of the image feature. + Example: + + ```python + >>> from transformers import Florence2VisionConfig, Florence2VisionModel + + >>> # Initializing a Florence2 Vision style configuration + >>> configuration = Florence2VisionConfig() + + >>> # Initializing a model (with random weights) + >>> model = Florence2VisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "florence2_vision" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + drop_path_rate=0.1, + patch_size=[7, 3, 3, 3], + patch_stride=[4, 2, 2, 2], + patch_padding=[3, 1, 1, 1], + patch_prenorm=[False, True, True, True], + enable_checkpoint=False, + dim_embed=[256, 512, 1024, 2048], + num_heads=[8, 16, 32, 64], + num_groups=[8, 16, 32, 64], + depths=[1, 1, 9, 1], + window_size=12, + projection_dim=1024, + visual_temporal_embedding=None, + image_pos_embed=None, + image_feature_source=["spatial_avg_pool", "temporal_avg_pool"], + **kwargs, + ): + self.drop_path_rate = drop_path_rate + self.patch_size = patch_size + self.patch_stride = patch_stride + self.patch_padding = patch_padding + self.patch_prenorm = patch_prenorm + self.enable_checkpoint = enable_checkpoint + self.dim_embed = dim_embed + self.num_heads = num_heads + self.num_groups = num_groups + self.depths = depths + self.window_size = window_size + self.projection_dim = projection_dim + self.visual_temporal_embedding = visual_temporal_embedding + self.image_pos_embed = image_pos_embed + self.image_feature_source = image_feature_source + + super().__init__(**kwargs) + + + diff --git a/vlm_fo1/model/multimodal_encoder/davit/image_processing_clip.py b/vlm_fo1/model/multimodal_encoder/davit/image_processing_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..131ceaa07cd055bebfb0fe985378af133cb40074 --- /dev/null +++ b/vlm_fo1/model/multimodal_encoder/davit/image_processing_clip.py @@ -0,0 +1,370 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for CLIP.""" + +from typing import Dict, List, Optional, Union + +import numpy as np + +from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from transformers.image_transforms import ( + convert_to_rgb, + get_resize_output_image_size, + resize, + to_channel_dimension_format, +) +from transformers.image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_kwargs, + validate_preprocess_arguments, +) +from transformers.utils import TensorType, is_vision_available, logging + + +logger = logging.get_logger(__name__) + + +if is_vision_available(): + import PIL + + +class CLIPImageProcessor(BaseImageProcessor): + r""" + Constructs a CLIP image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`): + Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the + `preprocess` method. + crop_size (`Dict[str, int]` *optional*, defaults to 224): + Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` + method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resize_mode: str = "squash", + candidate_sizes: List[int] = [384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1280, 1536, 1792, 2048], #[384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152, 1216, 1280, 1344, 1408, 1472, 1536, 1600, 1664, 1728, 1792, 1856, 1920, 1984, 2048], #[384, 448, 512, 576, 640, 704, 768, 1024, 1280, 1536, 1792, 2048, 2304, 2560], #[768, 1024, 1280, 1536, 1792, 2048] + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_center_crop: bool = True, + crop_size: Dict[str, int] = None, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"shortest_edge": 224} + size = get_size_dict(size, default_to_square=False) + crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} + crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.do_convert_rgb = do_convert_rgb + self._valid_processor_keys = [ + "images", + "do_resize", + "size", + "resample", + "do_center_crop", + "crop_size", + "do_rescale", + "rescale_factor", + "do_normalize", + "image_mean", + "image_std", + "do_convert_rgb", + "return_tensors", + "data_format", + "input_data_format", + ] + + # for backwards compatibility of KOSMOS-2 + if "use_square_size" in kwargs and kwargs["use_square_size"]: + self.size = {"height": size["shortest_edge"], "width": size["shortest_edge"]} + # Let's remove `use_square_size` (as it is removed from #27690), so the future Kosmos-2 image processors + # won't have this attr. being saved. (otherwise, it will enter this if branch while there is no more + # `shortest_edge` key. + delattr(self, "use_square_size") + + self.resize_mode = resize_mode + self.candidate_sizes = candidate_sizes + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge + resized to keep the input aspect ratio. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + default_to_square = True + if "shortest_edge" in size: + size = size["shortest_edge"] + default_to_square = False + elif "height" in size and "width" in size: + size = (size["height"], size["width"]) + else: + raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.") + + if self.resize_mode == "dynamic_square": + w, h = image.shape[1], image.shape[0] + area = w * h + + # 找到最接近的目标尺寸 + target_size = self.candidate_sizes[0] + min_diff = float('inf') + + for cur_size in self.candidate_sizes: + target_area = cur_size * cur_size + diff = abs(target_area - area) + if diff < min_diff: + min_diff = diff + target_size = cur_size + size = (target_size, target_size) + + output_size = get_resize_output_image_size( + image, + size=size, + default_to_square=default_to_square, + input_data_format=input_data_format, + ) + + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def preprocess( + self, + images: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_center_crop: bool = None, + crop_size: int = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the center crop. Only has an effect if `do_center_crop` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, param_name="size", default_to_square=False) + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True) + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) + + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + all_images = [] + for image in images: + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_center_crop: + image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + all_images.append(image) + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + for image in all_images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["CLIPImageProcessor"] diff --git a/vlm_fo1/model/multimodal_encoder/davit/modeling_davit.py b/vlm_fo1/model/multimodal_encoder/davit/modeling_davit.py new file mode 100644 index 0000000000000000000000000000000000000000..f326007414cc77e4ca5a240e147557580f6b2407 --- /dev/null +++ b/vlm_fo1/model/multimodal_encoder/davit/modeling_davit.py @@ -0,0 +1,527 @@ +import math +import torch +import torch.utils.checkpoint +from torch import nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from collections import OrderedDict +from einops import rearrange +from timm.models.layers import DropPath, trunc_normal_ + +from transformers.utils import ( + logging, +) + +logger = logging.get_logger(__name__) + + + +class MySequential(nn.Sequential): + def forward(self, *inputs): + for module in self._modules.values(): + if type(inputs) == tuple: + inputs = module(*inputs) + else: + inputs = module(inputs) + return inputs + + +class PreNorm(nn.Module): + def __init__(self, norm, fn, drop_path=None): + super().__init__() + self.norm = norm + self.fn = fn + self.drop_path = drop_path + + def forward(self, x, *args, **kwargs): + shortcut = x + if self.norm != None: + x, size = self.fn(self.norm(x), *args, **kwargs) + else: + x, size = self.fn(x, *args, **kwargs) + + if self.drop_path: + x = self.drop_path(x) + + x = shortcut + x + + return x, size + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.net = nn.Sequential(OrderedDict([ + ("fc1", nn.Linear(in_features, hidden_features)), + ("act", act_layer()), + ("fc2", nn.Linear(hidden_features, out_features)) + ])) + + def forward(self, x, size): + return self.net(x), size + + +class DepthWiseConv2d(nn.Module): + def __init__( + self, + dim_in, + kernel_size, + padding, + stride, + bias=True, + ): + super().__init__() + self.dw = nn.Conv2d( + dim_in, dim_in, + kernel_size=kernel_size, + padding=padding, + groups=dim_in, + stride=stride, + bias=bias + ) + + def forward(self, x, size): + B, N, C = x.shape + H, W = size + assert N == H * W + + x = self.dw(x.transpose(1, 2).view(B, C, H, W)) + size = (x.size(-2), x.size(-1)) + x = x.flatten(2).transpose(1, 2) + return x, size + + +class ConvEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__( + self, + patch_size=7, + in_chans=3, + embed_dim=64, + stride=4, + padding=2, + norm_layer=None, + pre_norm=True + ): + super().__init__() + self.patch_size = patch_size + + self.proj = nn.Conv2d( + in_chans, embed_dim, + kernel_size=patch_size, + stride=stride, + padding=padding + ) + + dim_norm = in_chans if pre_norm else embed_dim + self.norm = norm_layer(dim_norm) if norm_layer else None + + self.pre_norm = pre_norm + + def forward(self, x, size): + H, W = size + if len(x.size()) == 3: + if self.norm and self.pre_norm: + x = self.norm(x) + x = rearrange( + x, 'b (h w) c -> b c h w', + h=H, w=W + ) + + x = self.proj(x) + + _, _, H, W = x.shape + x = rearrange(x, 'b c h w -> b (h w) c') + if self.norm and not self.pre_norm: + x = self.norm(x) + + return x, (H, W) + + +class ChannelAttention(nn.Module): + + def __init__(self, dim, groups=8, qkv_bias=True): + super().__init__() + + self.groups = groups + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + def forward(self, x, size): + B, N, C = x.shape + + qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * (float(N) ** -0.5) + attention = q.transpose(-1, -2) @ k + attention = attention.softmax(dim=-1) + x = (attention @ v.transpose(-1, -2)).transpose(-1, -2) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + return x, size + + +class ChannelBlock(nn.Module): + + def __init__(self, dim, groups, mlp_ratio=4., qkv_bias=True, + drop_path_rate=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, + conv_at_attn=True, conv_at_ffn=True): + super().__init__() + + drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None + self.channel_attn = PreNorm( + norm_layer(dim), + ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias), + drop_path + ) + self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None + self.ffn = PreNorm( + norm_layer(dim), + Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act_layer=act_layer), + drop_path + ) + + def forward(self, x, size): + if self.conv1: + x, size = self.conv1(x, size) + x, size = self.channel_attn(x, size) + + if self.conv2: + x, size = self.conv2(x, size) + x, size = self.ffn(x, size) + + return x, size + + +def window_partition(x, window_size: int): + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int): + B = batch_size + # this will cause onnx conversion failed for dynamic axis, because treated as constant + # int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + def __init__(self, dim, num_heads, window_size, qkv_bias=True): + + super().__init__() + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = float(head_dim) ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, size): + + H, W = size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + x = window_partition(x, self.window_size) + x = x.view(-1, self.window_size * self.window_size, C) + + # W-MSA/SW-MSA + # attn_windows = self.attn(x_windows) + + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + attn = self.softmax(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + + # merge windows + x = x.view( + -1, self.window_size, self.window_size, C + ) + x = window_reverse(x, B, self.window_size, Hp, Wp) + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + return x, size + + +class SpatialBlock(nn.Module): + + def __init__(self, dim, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, drop_path_rate=0., act_layer=nn.GELU, + norm_layer=nn.LayerNorm, conv_at_attn=True, conv_at_ffn=True): + super().__init__() + + drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None + self.window_attn = PreNorm( + norm_layer(dim), + WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias), + drop_path + ) + self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None + self.ffn = PreNorm( + norm_layer(dim), + Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act_layer=act_layer), + drop_path + ) + + def forward(self, x, size): + if self.conv1: + x, size = self.conv1(x, size) + x, size = self.window_attn(x, size) + + if self.conv2: + x, size = self.conv2(x, size) + x, size = self.ffn(x, size) + return x, size + + +class DaViT(nn.Module): + """ DaViT: Dual-Attention Transformer + + Args: + in_chans (int): Number of input image channels. Default: 3. + num_classes (int): Number of classes for classification head. Default: 1000. + patch_size (tuple(int)): Patch size of convolution in different stages. Default: (7, 2, 2, 2). + patch_stride (tuple(int)): Patch stride of convolution in different stages. Default: (4, 2, 2, 2). + patch_padding (tuple(int)): Patch padding of convolution in different stages. Default: (3, 0, 0, 0). + patch_prenorm (tuple(bool)): If True, perform norm before convlution layer. Default: (True, False, False, False). + embed_dims (tuple(int)): Patch embedding dimension in different stages. Default: (64, 128, 192, 256). + num_heads (tuple(int)): Number of spatial attention heads in different stages. Default: (4, 8, 12, 16). + num_groups (tuple(int)): Number of channel groups in different stages. Default: (4, 8, 12, 16). + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True. + drop_path_rate (float): Stochastic depth rate. Default: 0.1. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + enable_checkpoint (bool): If True, enable checkpointing. Default: False. + conv_at_attn (bool): If True, performe depthwise convolution before attention layer. Default: True. + conv_at_ffn (bool): If True, performe depthwise convolution before ffn layer. Default: True. + """ + + def __init__( + self, + in_chans=3, + num_classes=1000, + depths=(1, 1, 3, 1), + patch_size=(7, 2, 2, 2), + patch_stride=(4, 2, 2, 2), + patch_padding=(3, 0, 0, 0), + patch_prenorm=(False, False, False, False), + embed_dims=(64, 128, 192, 256), + num_heads=(3, 6, 12, 24), + num_groups=(3, 6, 12, 24), + window_size=7, + mlp_ratio=4., + qkv_bias=True, + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + enable_checkpoint=False, + conv_at_attn=True, + conv_at_ffn=True + ): + super().__init__() + + self.num_classes = num_classes + self.embed_dims = embed_dims + self.num_heads = num_heads + self.num_groups = num_groups + self.num_stages = len(self.embed_dims) + self.enable_checkpoint = enable_checkpoint + assert self.num_stages == len(self.num_heads) == len(self.num_groups) + + num_stages = len(embed_dims) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)*2)] + + depth_offset = 0 + convs = [] + blocks = [] + for i in range(num_stages): + conv_embed = ConvEmbed( + patch_size=patch_size[i], + stride=patch_stride[i], + padding=patch_padding[i], + in_chans=in_chans if i == 0 else self.embed_dims[i - 1], + embed_dim=self.embed_dims[i], + norm_layer=norm_layer, + pre_norm=patch_prenorm[i] + ) + convs.append(conv_embed) + + block = MySequential( + *[ + MySequential(OrderedDict([ + ( + 'spatial_block', SpatialBlock( + embed_dims[i], + num_heads[i], + window_size, + drop_path_rate=dpr[depth_offset+j*2], + qkv_bias=qkv_bias, + mlp_ratio=mlp_ratio, + conv_at_attn=conv_at_attn, + conv_at_ffn=conv_at_ffn, + ) + ), + ( + 'channel_block', ChannelBlock( + embed_dims[i], + num_groups[i], + drop_path_rate=dpr[depth_offset+j*2+1], + qkv_bias=qkv_bias, + mlp_ratio=mlp_ratio, + conv_at_attn=conv_at_attn, + conv_at_ffn=conv_at_ffn, + ) + ) + ])) for j in range(depths[i]) + ] + ) + blocks.append(block) + depth_offset += depths[i]*2 + + self.convs = nn.ModuleList(convs) + self.blocks = nn.ModuleList(blocks) + + # self.norms = norm_layer(self.embed_dims[-1]) + # self.avgpool = nn.AdaptiveAvgPool1d(1) + # self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + @property + def dim_out(self): + return self.embed_dims[-1] + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, std=0.02) + for name, _ in m.named_parameters(): + if name in ['bias']: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.weight, 1.0) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1.0) + nn.init.constant_(m.bias, 0) + + def forward_features_unpool(self, x): + """ + forward until avg pooling + Args: + x (_type_): input image tensor + """ + input_size = (x.size(2), x.size(3)) + for conv, block in zip(self.convs, self.blocks): + x, input_size = conv(x, input_size) + if self.enable_checkpoint: + x, input_size = checkpoint.checkpoint(block, x, input_size) + else: + x, input_size = block(x, input_size) + return x + + # def forward_features(self, x): + # x = self.forward_features_unpool(x) + + # # (batch_size, num_tokens, token_dim) + # x = self.avgpool(x.transpose(1, 2)) + # # (batch_size, 1, num_tokens) + # x = torch.flatten(x, 1) + # x = self.norms(x) + + # return x + + def forward_features(self, x): + """ + forward until avg pooling + Args: + x (_type_): input image tensor + """ + outs = [] + input_size = (x.size(2), x.size(3)) + for i, (conv, block) in enumerate(zip(self.convs, self.blocks)): + x, input_size = conv(x, input_size) + if self.enable_checkpoint and self.training: + x, input_size = checkpoint.checkpoint(block, x, input_size, use_reentrant=False) + else: + x, input_size = block(x, input_size) + H, W = input_size + x_out = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W) + outs.append(x_out) + + # if i in self._out_features: + # norm_layer = getattr(self, f'norm{i}') + # x_out = norm_layer(x) + # H, W = input_size + # x_out = rearrange(x_out, 'b (h w) c -> b c h w', h=H, w=W) + # outs.append(x_out) + + return { + "image_features": outs, + "last_feat": outs[-1], + } + + def forward(self, x): + x = self.forward_features(x) + # x = self.head(x) + return x + + @classmethod + def from_config(cls, config, enable_checkpoint=False): + return cls( + depths=config.depths, + embed_dims=config.dim_embed, + num_heads=config.num_heads, + num_groups=config.num_groups, + patch_size=config.patch_size, + patch_stride=config.patch_stride, + patch_padding=config.patch_padding, + patch_prenorm=config.patch_prenorm, + drop_path_rate=config.drop_path_rate, + window_size=config.window_size, + enable_checkpoint=enable_checkpoint + ) diff --git a/vlm_fo1/model/multimodal_encoder/davit_aux_encoder.py b/vlm_fo1/model/multimodal_encoder/davit_aux_encoder.py new file mode 100755 index 0000000000000000000000000000000000000000..63883f3202072e2e5b3e9b09a274e0baf406b905 --- /dev/null +++ b/vlm_fo1/model/multimodal_encoder/davit_aux_encoder.py @@ -0,0 +1,98 @@ +from vlm_fo1.model.multimodal_encoder.base_encoder import AbsVisionTower +from vlm_fo1.model.multimodal_encoder.davit.configuration_davit import DavitConfig +from vlm_fo1.model.multimodal_encoder.davit.configs import model_configs, img_cfg +from vlm_fo1.model.multimodal_encoder.davit.modeling_davit import DaViT +from vlm_fo1.model.multimodal_encoder.davit.image_processing_clip import CLIPImageProcessor + +# Auxiliary DaViT-based vision tower for multi-modal encoder framework. +# This class manages configuration, processing, and dynamic instantiation of DaViT models. +class DavitVisionTower(AbsVisionTower): + def __init__(self, vision_tower_name, args, delay_load=False, image_size=768, aspect_ratio='squash'): + """ + Args: + vision_tower_name: Identifier string for model variant (usually a file name or config section). + args: Parent MM model/global config (currently ignored). + delay_load: If True, only config is loaded, not the weights/model (for e.g., lazy instantiation). + image_size: Target size to which images are resized (unless aspect_ratio=='dynamic'). + aspect_ratio: Controls how input aspect ratio is handled ('squash', 'dynamic', etc.). + """ + super().__init__() + self.is_loaded = False + self.vision_tower_name = vision_tower_name + self.aspect_ratio = aspect_ratio + self.image_size = image_size + + # In this implementation, training flag is ignored (always uses pretrained weights). + is_train = False + # if not delay_load: + # self.load_model(is_train, self.image_size, self.aspect_ratio) + # else: + # # Only load/prepare configuration (not model weights or modules) + # cfg_dict = model_configs[self.vision_tower_name.split('/')[-1].replace('.pth', '')] + # vision_cfg = DavitConfig.from_dict(cfg_dict) + # vision_cfg.image_size = image_size + # self.cfg_only = vision_cfg + self.load_model(is_train, self.image_size, self.aspect_ratio) + + def load_model(self, is_train=False, image_size=768, aspect_ratio='squash'): + """ + Actually loads the DaViT model (with weights) and its image processor. + Sets up resizing/aspect handling as needed. + """ + cfg_dict = model_configs[self.vision_tower_name.split('/')[-1].replace('.pth', '')] + vision_cfg = DavitConfig.from_dict(cfg_dict) + vision_cfg.image_size = image_size + self.image_tower = DaViT.from_config(config=vision_cfg, enable_checkpoint=True) + self.image_tower.config = vision_cfg + img_cfg['resize_mode'] = aspect_ratio + # If using 'dynamic' aspect ratio, disable resizing for the processor + if aspect_ratio == 'dynamic': # dynamic aspect ratio means no resizing, use the original image size, and the image_size parameter is not used + img_cfg['do_resize'] = False + self.image_processor = CLIPImageProcessor(**img_cfg) + + self.is_loaded = True + + def forward(self, images): + """ + Runs the auxiliary DaViT encoder. + Args: + images: Torch tensor, or list of tensors, of images to encode. + Returns: + List of image feature outputs (typically 4-stage outputs per image). + """ + # If input is a list of images, encode each separately. + if type(images) is list: + image_features = [] + for image in images: + # Forward pass: returns 4-stage outputs; caller must handle downstream selection/merging. + image_features.append(self.image_tower.forward(image.to(device=self.device, dtype=self.dtype))) # this returns 4 stage output + return image_features + else: + # Single image: compute features, return as a length-1 list for consistency. + # image_features = self.image_tower.forward(images.to(device=self.device, dtype=self.dtype)) # this returns 4 stage output + # return [image_features] # return the last layer for now + raise NotImplementedError + + @property + def dtype(self): + # Expose main tensor dtype to external utilities (e.g., for caller to move data to right dtype). + return self.image_tower.convs[0].proj.weight.dtype + + @property + def device(self): + # Expose main parameter device so inputs and other dependent modules use matching device. + return self.image_tower.convs[0].proj.weight.device + + @property + def config(self): + # Get configuration in loaded or 'config only' state + if self.is_loaded: + return self.image_tower.config + else: + return self.cfg_only + + @property + def hidden_size(self): + # Hidden size: sum of embedding dims (all multi-stage outputs). + return sum(self.image_tower.embed_dims) + diff --git a/vlm_fo1/model/multimodal_encoder/qwen2_5_vl/__init__.py b/vlm_fo1/model/multimodal_encoder/qwen2_5_vl/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vlm_fo1/model/multimodal_encoder/qwen2_5_vl/configuration_qwen2_5_vl.py b/vlm_fo1/model/multimodal_encoder/qwen2_5_vl/configuration_qwen2_5_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..aeaba127cc1f83b94c5c88683e040a1c342aadca --- /dev/null +++ b/vlm_fo1/model/multimodal_encoder/qwen2_5_vl/configuration_qwen2_5_vl.py @@ -0,0 +1,258 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen2_5_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation + + +class Qwen2_5_VLVisionConfig(PretrainedConfig): + model_type = "qwen2_5_vl" + base_config_key = "vision_config" + + def __init__( + self, + depth=32, + hidden_size=3584, + hidden_act="silu", + intermediate_size=3420, + num_heads=16, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + tokens_per_second=4, + window_size=112, + out_hidden_size=3584, + fullatt_block_indexes=[7, 15, 23, 31], + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.tokens_per_second = tokens_per_second + self.window_size = window_size + self.fullatt_block_indexes = fullatt_block_indexes + self.out_hidden_size = out_hidden_size + + +class Qwen2_5_VLConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a + Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 152064): + Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2_5_VLModel`] + hidden_size (`int`, *optional*, defaults to 8192): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 29568): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 80): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 64): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 80): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + vision_config (`Dict`, *optional*): + The config for the visual encoder initialization. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + + ```python + >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig + + >>> # Initializing a Qwen2_5_VL style configuration + >>> configuration = Qwen2_5_VLConfig() + + >>> # Initializing a model from the Qwen2-VL-7B style configuration + >>> model = Qwen2_5_VLForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_5_vl" + sub_configs = {"vision_config": Qwen2_5_VLVisionConfig} + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `Qwen2_5_VL` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=152064, + hidden_size=8192, + intermediate_size=29568, + num_hidden_layers=80, + num_attention_heads=64, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + tie_word_embeddings=False, + rope_theta=1000000.0, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=80, + attention_dropout=0.0, + vision_config=None, + rope_scaling=None, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + # and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations + # one can set it to "linear"/"dynamic" etc. to have scaled RoPE + # TODO: @raushan update config in the hub + if self.rope_scaling is not None and "type" in self.rope_scaling: + if self.rope_scaling["type"] == "mrope": + self.rope_scaling["type"] = "default" + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self, ignore_keys={"mrope_section"}) + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +__all__ = ["Qwen2_5_VLConfig"] \ No newline at end of file diff --git a/vlm_fo1/model/multimodal_encoder/qwen2_5_vl/modeling_qwen2_5_vl.py b/vlm_fo1/model/multimodal_encoder/qwen2_5_vl/modeling_qwen2_5_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..e1de8197edf489fbb94e5c48b9985b69ca66666b --- /dev/null +++ b/vlm_fo1/model/multimodal_encoder/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -0,0 +1,2072 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen2_5_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_varlen_func + from flash_attn.layers.rotary import apply_rotary_emb + +else: + flash_attn_varlen_func = None + apply_rotary_emb = None + + +if is_flash_attn_2_available(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward +else: + flash_attn_varlen_func = None + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Qwen2_5_VLConfig" + + +class Qwen2_5_VLMLP(nn.Module): + def __init__(self, config, bias: bool = False): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class Qwen2_5_VisionPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class Qwen2_5_VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen2_5_VLPatchMerger(nn.Module): + def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size), + nn.GELU(), + nn.Linear(self.hidden_size, dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + return x + + +def apply_rotary_pos_emb_flashatt( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + cos = cos.chunk(2, dim=-1)[0].contiguous() + sin = sin.chunk(2, dim=-1)[0].contiguous() + q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q) + k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k) + return q_embed, k_embed + + +class Qwen2_5_VLVisionFlashAttention2(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos() + sin = emb.sin() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin) + q = q.squeeze(0) + k = k.squeeze(0) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( + seq_length, -1 + ) + attn_output = self.proj(attn_output) + return attn_output + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + return q_embed, k_embed + + +class Qwen2_5_VLVisionAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos() + sin = emb.sin() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + attention_mask = torch.full( + [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) + attn_weights = attn_weights + attention_mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2_5_VLVisionSdpaAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos() + sin = emb.sin() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +QWEN2_5_VL_VISION_ATTENTION_CLASSES = { + "eager": Qwen2_5_VLVisionAttention, + "flash_attention_2": Qwen2_5_VLVisionFlashAttention2, + "sdpa": Qwen2_5_VLVisionSdpaAttention, +} + + +class Qwen2_5_VLVisionBlock(nn.Module): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__() + self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation]( + config.hidden_size, num_heads=config.num_heads + ) + self.mlp = Qwen2_5_VLMLP(config, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + position_embeddings=position_embeddings, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +Qwen2_5_VL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Qwen2_5_VLConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.", + Qwen2_5_VL_START_DOCSTRING, +) +class Qwen2_5_VLPreTrainedModel(PreTrainedModel): + config_class = Qwen2_5_VLConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions` + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv3d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): + config_class = Qwen2_5_VLVisionConfig + _no_split_modules = ["Qwen2_5_VLVisionBlock"] + + def __init__(self, config, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.fullatt_block_indexes = config.fullatt_block_indexes + self.window_size = config.window_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = Qwen2_5_VisionPatchEmbed( + patch_size=config.patch_size, + temporal_patch_size=config.temporal_patch_size, + in_channels=config.in_channels, + embed_dim=config.hidden_size, + ) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList( + [Qwen2_5_VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)] + ) + self.merger = Qwen2_5_VLPatchMerger( + dim=config.out_hidden_size, + context_dim=config.hidden_size, + spatial_merge_size=config.spatial_merge_size, + ) + self.gradient_checkpointing = False + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. + """ + hidden_states = self.patch_embed(hidden_states) + rotary_pos_emb = self.rot_pos_emb(grid_thw) + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=hidden_states.device, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings, use_reentrant=False + ) + else: + hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings) + + hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] + + return hidden_states + + +class Qwen2_5_VLRotaryEmbedding(nn.Module): + def __init__(self, config: Qwen2_5_VLConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + # @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block. In contrast to other models, Qwen2_5_VL has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Qwen2_5_VLAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: Qwen2_5_VLConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.rope_scaling = config.rope_scaling + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # Fix precision issues in Qwen2-VL float16 inference + # Replace inf values with zeros in attention weights to prevent NaN propagation + if query_states.dtype == torch.float16: + attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention): + """ + Qwen2_5_VL flash attention module, following Qwen2_5_VL attention module. This module inherits from `Qwen2_5_VLAttention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + else: + sliding_window = None + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2_5_VLSdpaAttention(Qwen2_5_VLAttention): + """ + Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Qwen2Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Qwen2_5_VLModel is using Qwen2_5_VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +QWEN2_5_VL_ATTENTION_CLASSES = { + "eager": Qwen2_5_VLAttention, + "flash_attention_2": Qwen2_5_VLFlashAttention2, + "sdpa": Qwen2_5_VLSdpaAttention, +} + + +class Qwen2_5_VLDecoderLayer(nn.Module): + def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.use_sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + self.self_attn = QWEN2_5_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +@add_start_docstrings( + "The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.", + Qwen2_5_VL_START_DOCSTRING, +) +class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): + def __init__(self, config: Qwen2_5_VLConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2_5_VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + use_reentrant=False + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Qwen2_5_VL. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen2_5_VLConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`Qwen2_5_VLConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +@dataclass +class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): + """ + Base class for Qwen2_5_VL causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +QWEN2_5_VL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + pixel_values (`torch.FloatTensor` of shape `(seq_length, num_channels * image_size * image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses + [`Qwen2_5_VLImageProcessor`] for processing images. + pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)): + The tensors corresponding to the input videos. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses + [`Qwen2_5_VLImageProcessor`] for processing videos. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. +""" + + +class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + config_class = Qwen2_5_VLConfig + _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config) + self.model = Qwen2_5_VLModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.rope_deltas = None # cache rope_deltas here + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embedding for text part. + Examples: + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. + tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity. + temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. + interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second + + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration + + >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + mask = input_ids == self.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + + mask = input_ids == self.config.video_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + video_mask = mask_expanded.to(inputs_embeds.device) + + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): + # calculate RoPE index once per generation in the pre-fill stage only + if ( + (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2_5_VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + second_per_grid_ts=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + use_cache=use_cache, + **kwargs, + ) + + # Qwen2-5-VL position_ids are prepareed with rope_deltas in forward + model_inputs["position_ids"] = None + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None + + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: Optional[torch.LongTensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample tensor. + These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + vision_start_mask = input_ids == vision_start_token_id + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[torch.LongTensor] = None, + **model_kwargs, + ) -> Tuple[torch.LongTensor, Dict[str, Any]]: + # Overwritten -- Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"] + + def _expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw, list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + samples = torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "second_per_grid_ts": + if not isinstance(dict_to_expand[key], list): + raise TypeError( + f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead." + ) + tensor = torch.tensor(dict_to_expand[key]) + lengths = list(video_nums) + tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size) + dict_to_expand[key] = tensor.tolist() + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + # input_ids is required for expanding visual inputs + # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs. + if input_ids is not None and input_ids.numel() != 0: + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + +__all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel"] diff --git a/vlm_fo1/model/multimodal_encoder/qwen2_5_vl/processing_qwen2_5_vl.py b/vlm_fo1/model/multimodal_encoder/qwen2_5_vl/processing_qwen2_5_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..0f72e94622542f7934edd9a1f021f113db7a49ae --- /dev/null +++ b/vlm_fo1/model/multimodal_encoder/qwen2_5_vl/processing_qwen2_5_vl.py @@ -0,0 +1,239 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen2_5_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Union + +from transformers.feature_extraction_utils import BatchFeature +from transformers.image_utils import ImageInput, VideoInput +from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + + +class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False): + fps: Union[List[float], float] + + +class Qwen2_5_VLImagesKwargs(ImagesKwargs): + min_pixels: Optional[int] + max_pixels: Optional[int] + patch_size: Optional[int] + temporal_patch_size: Optional[int] + merge_size: Optional[int] + + +class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Qwen2_5_VLImagesKwargs + videos_kwargs: Qwen2_5_VLVideosProcessorKwargs + _defaults = { + "text_kwargs": { + "padding": False, + }, + "videos_kwargs": {"fps": 2.0}, + } + + +class Qwen2_5_VLProcessor(ProcessorMixin): + r""" + Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor. + [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the + [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information. + Args: + image_processor ([`Qwen2VLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template"] + + image_processor_class = "AutoImageProcessor" + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + + def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs): + self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token + self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + videos: VideoInput = None, + **kwargs: Unpack[Qwen2_5_VLProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to + Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch + tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. + - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. + - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`. + """ + output_kwargs = self._merge_kwargs( + Qwen2_5_VLProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: + image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_inputs = {} + image_grid_thw = None + + if videos is not None: + videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["images_kwargs"]) + video_grid_thw = videos_inputs["video_grid_thw"] + + fps = output_kwargs["videos_kwargs"].pop("fps", 2.0) + if isinstance(fps, (int, float)): + second_per_grid_ts = [self.image_processor.temporal_patch_size / fps] * len(video_grid_thw) + elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw): + second_per_grid_ts = [self.image_processor.temporal_patch_size / tmp for tmp in fps] + else: + raise ValueError( + f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number." + ) + videos_inputs.update({"second_per_grid_ts": second_per_grid_ts}) + + else: + videos_inputs = {} + video_grid_thw = None + + if not isinstance(text, list): + text = [text] + + if image_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + text[i] = text[i].replace( + self.image_token, + "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + if video_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.video_token in text[i]: + text[i] = text[i].replace( + self.video_token, + "<|placeholder|>" * (video_grid_thw[index].prod() // merge_length), + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.video_token) + + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + def post_process_image_text_to_text( + self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs + ): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + skip_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method. + Clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method. + **kwargs: + Additional arguments to be passed to the tokenizer's `batch_decode method`. + + Returns: + `List[str]`: The decoded text. + """ + return self.tokenizer.batch_decode( + generated_outputs, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + return names_from_processor + ["second_per_grid_ts"] + + +__all__ = ["Qwen2_5_VLProcessor"] \ No newline at end of file diff --git a/vlm_fo1/model/multimodal_encoder/qwen2_5_vl_encoder.py b/vlm_fo1/model/multimodal_encoder/qwen2_5_vl_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..664b3c778c63ff9995087b4bbbe1a7e0ed6dcde2 --- /dev/null +++ b/vlm_fo1/model/multimodal_encoder/qwen2_5_vl_encoder.py @@ -0,0 +1,301 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from vlm_fo1.model.multimodal_encoder.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel +from transformers.models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor +from torchvision.transforms import ToPILImage + +class VisionFeaturesGather: + """ + Collects and manages intermediate features for multi-level visual representation extraction + (used for region feature/ROIAlign task). Each forward pass (per image) builds up a list of features. + """ + def __init__(self) -> None: + self.features_list = [] + self.grid_thw = None + self.window_index = None + self.merge_size = None + + def reset(self): + """Clear all states before starting a new feature-gathering process.""" + self.features_list.clear() + self.grid_thw = None + self.window_index = None + self.merge_size = None + + def set_params(self, grid_thw, window_index, merge_size): + """Store spatial and merge information for the current image or batch.""" + self.grid_thw = grid_thw + self.window_index = window_index + self.merge_size = merge_size + + def append(self, element): + """Append a set of features (typically per layer in encoder).""" + self.features_list.append(element) + + def extract_multi_level_features(self): + """ + Assemble all gathered multi-level features into canonical tensor forms. + + The goal: for each visual sample, produce a list of region-aligned feature maps + (e.g., multiple stage outputs for downstream region patching/ROIAlign). + + Returns: + List of features, where each element is a list [stage1, stage2, ...] for one image. + """ + # Concatenate all feature tensors along hidden dimension: [seq_len, hidden_size * k] + concat_features = torch.cat(self.features_list, dim=1) + merge_unit = self.merge_size * self.merge_size + seq_len = concat_features.shape[0] + + # Rearrange into [windows, merge_unit, hidden_dim*layers] + concat_features = concat_features.reshape(seq_len // merge_unit, merge_unit, -1) + reverse_indices = torch.argsort(self.window_index) + concat_features = concat_features[reverse_indices, :, :] + concat_features = concat_features.reshape(seq_len, -1) + + # Split features for each image/video by product of grid h and w (per sample) + split_size = (self.grid_thw[:, 1] * self.grid_thw[:, 2]).tolist() + split_features = list(torch.split(concat_features, split_size, dim=0)) + assert len(split_features) == self.grid_thw.shape[0] + for i in range(len(split_features)): + # Recover original grid shape and merge windowing into stages, then split + _, grid_h, grid_w = self.grid_thw[i] + merge_h = grid_h // self.merge_size + merge_w = grid_w // self.merge_size + split_features[i] = split_features[i].reshape(merge_h, merge_w, merge_unit, -1) + split_features[i] = split_features[i].reshape(merge_h, merge_w, self.merge_size, self.merge_size, -1) + split_features[i] = split_features[i].permute(0, 2, 1, 3, 4) + split_features[i] = split_features[i].flatten(start_dim=0, end_dim=-2) + # Split [h, w, dim] into k tensors [1, dim/k, h, w] (for compatibility with multi-stage vision encoding) + hidden_dim = split_features[i].shape[-1] + split_dim = hidden_dim // len(self.features_list) + split_features[i] = split_features[i].reshape(grid_h, grid_w, -1) + split_features[i] = [ + split_features[i][..., j*split_dim:(j+1)*split_dim].permute(2, 0, 1).unsqueeze(0) + for j in range(len(self.features_list)) + ] + + return split_features + +# Global gather object to pass into Qwen2_5_VisionTransformer for monkey-patched feature gathering +GATHER = VisionFeaturesGather() + +# --------------------------------- Monkey Patch --------------------------------------- +def custom_forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + """ + Custom forward used with monkey patch to support multi-level feature extraction. + Applies patch embedding, window partition, position embedding, and passes through all blocks. + Optionally collects features at each 'fullatt' block for multi-region support. + + Args: + hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + Temporal, height, width of each feature sequence. + + Returns: + `torch.Tensor`: Final hidden states after MLP head (merger). + """ + hidden_states = self.patch_embed(hidden_states) + rotary_pos_emb = self.rot_pos_emb(grid_thw) + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=hidden_states.device, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # FA2 requires that cu_seqlens_q must have dtype int32 + # torch.onnx.export requires that cu_seqlens_q must match grid_thw dtype + # See https://github.com/huggingface/transformers/pull/34852 for more info + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + # If monkey-patched feature gather enabled, prepare to collect intermediate features + if hasattr(self, 'vision_features_gather'): + self.vision_features_gather.reset() + self.vision_features_gather.set_params(grid_thw, window_index, self.spatial_merge_size) + + # Forward pass through all transformer blocks; collect intermediate features if needed + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings, use_reentrant=False + ) + else: + hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings) + + if hasattr(self, 'vision_features_gather'): + # Capture hidden states at all 'full attention' blocks as multi-level features + if layer_num in self.fullatt_block_indexes: + # This property is set by monkey patching + self.vision_features_gather.append(hidden_states.clone()) + + hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] + + return hidden_states + +def init_vision_features_gather(self, vision_features_gather): + """ + Helper method for monkey patch to inject a VisionFeaturesGather instance into model. + """ + self.vision_features_gather = vision_features_gather + +def replace_qwen_vit_forward(): + """ + Monkey-patch Qwen2_5_VisionTransformer to use custom forward with multi-level feature support. + """ + Qwen2_5_VisionTransformerPretrainedModel.forward = custom_forward + Qwen2_5_VisionTransformerPretrainedModel.init_vision_features_gather = init_vision_features_gather + + +class Qwen2_5_VlVisionTower(nn.Module): + """ + Vision backbone wrapper for Qwen2.5-VL (Vision Transformer). + Handles both standard and region-level (multi-level) encoding with optional monkey patch logic. + """ + def __init__(self, image_tower, args, delay_load=False, min_pixels=56*56, max_pixels=2048*2048): + super().__init__() + + self.is_loaded = False + + self.image_tower_name = image_tower + + # Determine if multi-level region feature is to be enabled (monkey patch required) + self.use_vision_tower_region_feature = getattr(args, 'mm_use_vision_tower_region_feature', False) + if self.use_vision_tower_region_feature: + replace_qwen_vit_forward() # Monkey patch: add multi-level feature extraction logic + self.min_pixels = min_pixels + self.max_pixels = max_pixels + self.delay_load = delay_load + print (f"Qwen2_5_VlVisionTower loading_info: delay_load: {delay_load} min_pixels: {min_pixels} max_pixels: {max_pixels}") + + # if not delay_load: + # self.load_model() + # else: + # # Defer actual model loading to support (e.g.) model parallel or delayed download scenarios + # self.cfg_only = args.vision_config + self.cfg_only = args.vision_config + self.load_model(model_path=args.name_or_path) + + def load_model(self, model_path=None, image_size=336, is_train=True): + """ + Actually load Qwen2.5 Vision Tower backbone and processor. + Sets up the image tower and patch feed pipeline. + """ + self.image_tower = Qwen2_5_VisionTransformerPretrainedModel._from_config(self.cfg_only, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16) + # print(f'Qwen2_5_VlVisionTower loading_info: {loading_info}') + + if model_path is not None: + self.image_processor = Qwen2VLImageProcessor.from_pretrained(model_path, min_pixels=self.min_pixels, max_pixels=self.max_pixels) + else: + self.image_processor = Qwen2VLImageProcessor.from_pretrained(self.image_tower_name, min_pixels=self.min_pixels, max_pixels=self.max_pixels) + + if self.use_vision_tower_region_feature: + # Setup gather instance for monkey-patched feature extraction + self.image_tower.init_vision_features_gather(GATHER) + self.is_loaded = True + + def convert_image_format(self, image): + """ + Convert raw image tensor to pre-processed model input tensor and grid shape, using appropriate processor. + Handles PIL conversion and applies preprocessor for Qwen2.5-VL. + """ + pil_image = ToPILImage()(image) + inputs = self.image_processor(images=pil_image, videos=None, return_tensors="pt") + return inputs['pixel_values'], inputs['image_grid_thw'] + + def forward(self, images, image_grid_thws=[]): + """ + Forward pass for a batch (list) of images. + Returns image features, gridTHWs, and optional multi-level features for each input image. + """ + if type(images) is list: + image_features = [] + multi_level_features_list = [] + output_image_grid_thws = [] + + for i, image in enumerate(images): + # If no grid provided, convert and infer via processor + if image_grid_thws is None or len(image_grid_thws) == 0: + image, image_grid_thw = self.convert_image_format(image=image) + else: + image_grid_thw = image_grid_thws[i] + image_forward_out = self.image_tower(image.to(device=self.device, dtype=self.dtype), grid_thw=image_grid_thw.to(device=self.device)) + image_feature = image_forward_out.unsqueeze(0).to(self.dtype) + + image_features.append(image_feature) + output_image_grid_thws.append(image_grid_thw) + + # If region feature mode enabled, collect multi-level features for this image + if self.use_vision_tower_region_feature: + multi_level_features_list.append(self.get_multi_level_features()[0]) + + else: + raise NotImplementedError("Qwen2_5_VlVisionTower only supports list-of-image input") + + return image_features, output_image_grid_thws, multi_level_features_list + + def get_multi_level_features(self): + """ + Get the current (last-processed) multi-level region features from the VisionFeaturesGather helper. + Used in region-feature/ROIAlign branches. + """ + multi_level_features = self.image_tower.vision_features_gather.extract_multi_level_features() + return multi_level_features + + @property + def dummy_feature(self): + """Returns a zero-vector feature, for use as fallback/null visual token.""" + return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) + + @property + def dtype(self): + """Report vision tower's expected/active tensor dtype (inferred from real weights).""" + return self.image_tower.dtype + + @property + def device(self): + """Report vision tower's tensor device (cuda/cpu) for autoflow/compatibility.""" + return self.image_tower.device + + @property + def config(self): + """Yield config, for both loaded-and-ready and 'config only' modes (delay load etc).""" + if self.is_loaded: + return self.image_tower.config + else: + return self.cfg_only + + @property + def hidden_size(self): + """Return backbone output hidden size (for proj or post-processing modules).""" + return self.config.out_hidden_size + + @property + def num_patches(self): + """Return number of vision tokens (patches) in processed image.""" + return (self.config.image_size // self.config.patch_size) ** 2 + diff --git a/vlm_fo1/model/multimodal_projector/builder.py b/vlm_fo1/model/multimodal_projector/builder.py new file mode 100755 index 0000000000000000000000000000000000000000..0d0b271480c89d6b6c63d03533c9fe0f36ddfa44 --- /dev/null +++ b/vlm_fo1/model/multimodal_projector/builder.py @@ -0,0 +1,222 @@ +import torch +import torch.nn as nn +import re +from .honeybee import CAbstractor +from functools import partial +import numpy as np +from torch.nn.init import trunc_normal_ +from torch.nn import functional as F +import math + + +class IdentityMap(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, *args, **kwargs): + return x + + @property + def config(self): + return {"mm_projector_type": 'identity'} + + +class SimpleResBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.pre_norm = nn.LayerNorm(channels) + + self.proj = nn.Sequential( + nn.Linear(channels, channels), + nn.GELU(), + nn.Linear(channels, channels) + ) + def forward(self, x): + x = self.pre_norm(x) + return x + self.proj(x) + + +def build_vision_projector(config, delay_load=False, **kwargs): + projector_type = getattr(config, 'mm_projector_type', 'linear') + + if projector_type == 'linear': + return nn.Linear(config.mm_hidden_size, config.hidden_size) + if projector_type == "cabstract": + n_query = getattr(config, 'mm_projector_n_query', None) + image_size = getattr(config, 'image_size', None) + if not n_query: + n_query = kwargs.get("mm_projector_n_query",144) + if not image_size: + image_size = kwargs.get("image_size",336) + vokens = int(image_size/14*image_size/14) + print ("n_query",n_query) + print ("image_size",image_size) + print ("vokens",vokens) + + return CAbstractor(vokens, config.mm_hidden_size, config.hidden_size, num_queries=n_query) + + if projector_type == "tokenpacker": + #TokenPacker(hidden_size=config.hidden_size, scale_factor=config.scale_factor) + image_size = kwargs.get("image_size",448) + return TokenPacker(hidden_size=config.hidden_size, mm_hidden_size=config.mm_hidden_size, raw_grid=int(image_size/14)) + + + mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) + if mlp_gelu_match: + mlp_depth = int(mlp_gelu_match.group(1)) + modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(config.hidden_size, config.hidden_size)) + return nn.Sequential(*modules) + + if projector_type == 'identity': + return IdentityMap() + + raise ValueError(f'Unknown projector type: {projector_type}') + +def build_vision_projector_aux(config, delay_load=False, **kwargs): + projector_type = getattr(config, 'mm_projector_aux_type', 'linear') + + if projector_type == 'linear': + return nn.Linear(config.mm_region_hidden_size, config.hidden_size) + if projector_type == "cabstract": + n_query = getattr(config, 'mm_projector_n_query', None) + image_size = getattr(config, 'image_size', None) + if not n_query: + n_query = kwargs.get("mm_projector_n_query",144) + if not image_size: + image_size = kwargs.get("image_size",336) + vokens = int(image_size/14*image_size/14) + print ("n_query",n_query) + print ("image_size",image_size) + print ("vokens",vokens) + + return CAbstractor(vokens, config.mm_region_hidden_size, config.hidden_size, num_queries=n_query) + + if projector_type == "tokenpacker": + #TokenPacker(hidden_size=config.hidden_size, scale_factor=config.scale_factor) + image_size = kwargs.get("image_size",448) + return TokenPacker(hidden_size=config.hidden_size, mm_hidden_size=config.mm_region_hidden_size, raw_grid=int(image_size/14)) + + + mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) + if mlp_gelu_match: + mlp_depth = int(mlp_gelu_match.group(1)) + modules = [nn.Linear(config.mm_region_hidden_size, config.hidden_size)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(config.hidden_size, config.hidden_size)) + return nn.Sequential(*modules) + + if projector_type == 'identity': + return IdentityMap() + + raise ValueError(f'Unknown projector type: {projector_type}') + +class TokenPacker(nn.Module): + def __init__( + self, + raw_grid=32, + embed_dim=1024, + num_heads=1024//128, + hidden_size=4096, + mm_hidden_size=3200, + scale_factor=2, + norm_layer=partial(nn.LayerNorm, eps=1e-6) + ): + super().__init__() + if raw_grid%scale_factor!=0: + raise ValueError("scale_factor must be divisible by grid size") + self.raw_grid = raw_grid + self.grid_size = raw_grid//scale_factor + self.num_queries = self.grid_size ** 2 + self.embed_dim = embed_dim + self.num_heads = num_heads + self.scale_factor = scale_factor + kv_dim = mm_hidden_size + self.q_proj_1 = nn.Linear(kv_dim, embed_dim, bias=False) + + k_modules = [nn.Linear(mm_hidden_size*4, 1024)] + for _ in range(1,2): + k_modules.append(nn.GELU()) + k_modules.append(nn.Linear(1024, 1024)) + self.k_proj_1 = nn.Sequential(*k_modules) + + v_modules = [nn.Linear(mm_hidden_size*4, 1024)] + for _ in range(1,2): + v_modules.append(nn.GELU()) + v_modules.append(nn.Linear(1024, 1024)) + self.v_proj_1 = nn.Sequential(*v_modules) + + self.ln_q_1 = norm_layer(embed_dim) + self.ln_k_1 = norm_layer(embed_dim) + self.ln_v_1 = norm_layer(embed_dim) + + self.clip_attn = nn.MultiheadAttention(embed_dim, num_heads) + + modules = [nn.Linear(1024, hidden_size)] + for _ in range(1, 2): + modules.append(nn.GELU()) + modules.append(nn.Linear(hidden_size, hidden_size)) + self.mlp = nn.Sequential(*modules) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def divide_feature(self, x, kernel_size, token_num, N, c): + h = w = int(token_num**0.5) + + #print (x.shape) + reshape_x = x.reshape(h, w, N, c).reshape(h//kernel_size, kernel_size, w, N, c) + reshape_x = reshape_x.permute(0,2,1,3,4) + reshape_x = reshape_x.reshape(h//kernel_size, w//kernel_size, kernel_size, kernel_size, N, c) + reshape_x = reshape_x.permute(0,1,3,2,4,5).reshape(h//kernel_size, w//kernel_size, kernel_size*kernel_size, N, c) + reshape_x = reshape_x.permute(2,0,1,3,4).reshape(kernel_size*kernel_size, -1, c) + + return reshape_x + + def forward(self, x, attn_mask=None): + + x_multi = x[1] # mulit-level + x = x[0] # original single-level + + key = self.ln_k_1(self.k_proj_1(x_multi)).permute(1, 0, 2) + value = self.ln_v_1(self.v_proj_1(x_multi)).permute(1, 0, 2) + + token_num, N, c = key.shape + + q = F.interpolate(x.reshape(x.shape[0],self.raw_grid,self.raw_grid,-1).float().permute(0,3,1,2), size=(self.grid_size, self.grid_size), mode='bilinear').permute(0,2,3,1) ## fix + q = q.reshape(q.shape[0], -1, q.shape[-1]).to(x.dtype) + + query = self.ln_q_1(self.q_proj_1(q)).permute(1, 0, 2) + + reshape_query = self.divide_feature(query, 1, self.num_queries, N, c) + reshape_key = self.divide_feature(key, self.scale_factor, token_num, N, c) + reshape_value = self.divide_feature(value, self.scale_factor, token_num, N, value.shape[-1]) + + out = self.clip_attn( + reshape_query, + reshape_key, + reshape_value, + attn_mask=attn_mask)[0] + + x = out + x = x.reshape(self.num_queries, N, -1) + x = x.permute(1, 0, 2) + + x = self.mlp(x) + return x + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) + + diff --git a/vlm_fo1/model/multimodal_projector/honeybee.py b/vlm_fo1/model/multimodal_projector/honeybee.py new file mode 100644 index 0000000000000000000000000000000000000000..dea1f689ccf29c7024ea9b14edca60d847b636de --- /dev/null +++ b/vlm_fo1/model/multimodal_projector/honeybee.py @@ -0,0 +1,142 @@ +from functools import partial +import torch +import torch.nn as nn +from einops import rearrange +from timm.layers import LayerNorm, LayerNorm2d +from timm.models.regnet import RegStage + + +def build_pos_embeds( + pos_emb: bool, num_input_tokens: int, vision_hidden_size: int +): + # pos emb + if pos_emb: + pos_emb = torch.nn.Parameter(torch.zeros(1, num_input_tokens, vision_hidden_size)) + nn.init.trunc_normal_(pos_emb, mean=0.0, std=0.02) + else: + pos_emb = None + + return pos_emb + +def build_prenorm(prenorm, encoder_hidden_size): + if prenorm: + prenorm = LayerNorm(encoder_hidden_size) + else: + prenorm = None + return prenorm + + +def build_mlp(depth, hidden_size, output_hidden_size): + layers = [nn.Linear(hidden_size, output_hidden_size)] + for _ in range(1, depth): + layers.append(nn.SiLU()) + layers.append(nn.Linear(output_hidden_size, output_hidden_size)) + return nn.Sequential(*layers) + + +class CAbstractor(nn.Module): + """Base projector class""" + + def __init__( + self, + num_input_tokens: int, + encoder_hidden_size: int, + output_hidden_size: int, + hidden_size: int = 1024, + depth: int = 3, + mlp_depth: int = 2, + num_queries: int = 144, + pos_emb: bool = True, + prenorm: bool = False + ): + super().__init__() + self.num_input_tokens = num_input_tokens + self.encoder_hidden_size = encoder_hidden_size + self.output_hidden_size = output_hidden_size + self.mlp_depth = mlp_depth + self.depth = depth + self.num_queries = num_queries + self.hidden_size = hidden_size + + # pos emb + self.pos_emb = build_pos_embeds(pos_emb, num_input_tokens, encoder_hidden_size) + + self.prenorm = build_prenorm(prenorm, encoder_hidden_size) + + self.build_net() + + def build_net(self): + encoder_hidden_size = self.encoder_hidden_size + hidden_size = self.hidden_size + output_hidden_size = self.output_hidden_size + depth = self.depth + mlp_depth = self.mlp_depth + n_queries = self.num_queries + + assert (n_queries ** 0.5).is_integer(), "n_queries must be square number" + hw = int(n_queries ** 0.5) + + # RegBlock = ResBlock + SE + RegBlock = partial( + RegStage, + stride=1, + dilation=1, + act_layer=nn.SiLU, + norm_layer=LayerNorm2d, + ) + + s1 = RegBlock( + depth, + encoder_hidden_size, + hidden_size, + ) + sampler = nn.AdaptiveAvgPool2d((hw, hw)) + s2 = RegBlock( + depth, + hidden_size, + hidden_size, + ) + + self.net = nn.Sequential(s1, sampler, s2) + self.readout = build_mlp(mlp_depth, hidden_size, output_hidden_size) + + def _forward(self, x): + # x: [B, L, dim] + # x = x[:, 1:] # drop cls token and 2d forward @Kyusong, If we output CLS token from vision tower, u can use this + hw = int(x.size(1) ** 0.5) + x = rearrange(x, "b (h w) d -> b d h w", h=hw, w=hw) + x = self.net(x) + x = rearrange(x, "b d h w -> b (h w) d") + x = self.readout(x) + + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: (B, L, encoder_hidden_size) tensor from the visual backbone (CLIP visual encoder), including cls token. + """ + if self.prenorm is not None: + x = self.prenorm(x) + + if self.pos_emb is not None: + x += self.pos_emb + + x = self._forward(x) # (B, L, output_hidden_size) + + return x + + + +if __name__ == "__main__": + B = 2 # batch size + L = 576 # number of input token + H = 1024 # hidden size + + n_query = 256 + output_h = 4096 + + x = torch.FloatTensor(B, L, H) + m = CAbstractor(L, H, output_h, num_queries=n_query) + y = m(x) + print(y.shape) # B, N_Query, output_H diff --git a/vlm_fo1/model/multimodal_visual_prompt_encoder/hybrid_finegrained_region_encoder.py b/vlm_fo1/model/multimodal_visual_prompt_encoder/hybrid_finegrained_region_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..753cf8fbd85530b6efe42c3c3dc9012a518eb658 --- /dev/null +++ b/vlm_fo1/model/multimodal_visual_prompt_encoder/hybrid_finegrained_region_encoder.py @@ -0,0 +1,469 @@ +import torch +import torch.nn as nn +from typing import List, Union +import torch.nn.functional as F +from torchvision.ops import roi_align +import math + +from vlm_fo1.model.multimodal_visual_prompt_encoder.simple_fpn import SimpleFP + + +def generate_2d_position_embedding(height, width, dim, device): + """Generate a 2D positional encoding for a feature map. + + Args: + height (int): Height of the feature map. + width (int): Width of the feature map. + dim (int): Dimensionality of the positional embedding (should match channel count). + device: Torch device on which to allocate tensors. + + Returns: + pos_embed (Tensor): Positional encoding of shape [H, W, dim]. + """ + # Generate grid coordinate vectors of length H and W + y_pos = torch.arange(height, dtype=torch.float32, device=device) + x_pos = torch.arange(width, dtype=torch.float32, device=device) + + # Normalize grid values to [0, 1] + y_pos = y_pos / height + x_pos = x_pos / width + + # Create mesh grid (Y: rows, X: cols) + y_grid, x_grid = torch.meshgrid(y_pos, x_pos, indexing='ij') + + scale = 2 * math.pi + # Calculate positions for sine/cosine encoding + quarter_dim = dim // 4 + dim_t = torch.arange(quarter_dim, dtype=torch.float32, device=device) + dim_t = 10000 ** (2 * (dim_t // 2) / quarter_dim) if quarter_dim > 0 else torch.tensor([1.0], device=device) + + # X direction encoding + x_embed = x_grid.unsqueeze(-1) * scale # [H, W, 1] + pos_x = x_embed / dim_t # [H, W, quarter_dim] + pos_x = torch.stack((pos_x.sin(), pos_x.cos()), dim=-1).flatten(-2) # Alternating sin/cos + + # Y direction encoding + y_embed = y_grid.unsqueeze(-1) * scale # [H, W, 1] + pos_y = y_embed / dim_t # [H, W, quarter_dim] + pos_y = torch.stack((pos_y.sin(), pos_y.cos()), dim=-1).flatten(-2) # Alternating sin/cos + + # Concatenate along the last dimension to make [H, W, dim] + pos_embed = torch.cat([pos_y, pos_x], dim=-1) + + return pos_embed + +def gen_sineembed_for_position(pos_tensor, dim_of_pos_feats): + """Generate sine/cosine positional embedding for ROI position(s). + + Args: + pos_tensor (Tensor): Shape [batch_size, N, 4] (format: [cx, cy, w, h] in normalized [0, 1]) + dim_of_pos_feats (int): Output embedding dimensionality (#positional channels). + + Returns: + pos (Tensor): [batch_size, N, dim_of_pos_feats * (2, 4, ...)] + """ + scale = 2 * math.pi + dim_t = torch.arange( + dim_of_pos_feats, dtype=torch.float32, device=pos_tensor.device + ) + dim_t = 10000 ** (2 * (dim_t // 2) / dim_of_pos_feats) + x_embed = pos_tensor[:, :, 0] * scale + y_embed = pos_tensor[:, :, 1] * scale + + # Generate encodings for cx, cy + pos_x = x_embed[:, :, None] / dim_t + pos_y = y_embed[:, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3 + ).flatten(2) + pos_y = torch.stack( + (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3 + ).flatten(2) + if pos_tensor.size(-1) == 2: + # [cx, cy] input + pos = torch.cat((pos_y, pos_x), dim=2) + elif pos_tensor.size(-1) == 4: + # [cx, cy, w, h] input + w_embed = pos_tensor[:, :, 2] * scale + pos_w = w_embed[:, :, None] / dim_t + pos_w = torch.stack( + (pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3 + ).flatten(2) + + h_embed = pos_tensor[:, :, 3] * scale + pos_h = h_embed[:, :, None] / dim_t + pos_h = torch.stack( + (pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3 + ).flatten(2) + + # Concatenate encodings for [cy, cx, w, h] + pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) + else: + raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1))) + return pos + + +class HFREModule(nn.Module): + """Hybrid Finegrained Region Encoder (HFREModule). + + Handles multi-level ROI region features, optional position embedding, and feature combination for hybrid visual prompt encoding. + + Args: + roi_output_size (Optional[int]): Output spatial size for ROIAlign. + region_feature_dim (int): The output dimension for region features. + apply_position_embedding (bool): Whether positional embedding is used in region features. + pos_embedding_strategy (str): 'bbox_based', 'feature_map_based', or 'hybrid'. + use_vt_region_feature_only (bool): Only use vision tower features, skip auxiliary. + use_vision_tower_region_feature (bool): Whether to include vision tower region features. + region_feature_combination (str): Combination method: 'concat', 'mean', etc. + use_separate_mlp_for_regions (bool): Whether to MLP project each region type separately. + apply_region_layer_norm (bool): Whether to apply layernorm to region features. + vision_tower_region_feature_dim (int): #channels for vision-tower region feature. + vision_tower_spatial_scale (float): Spatial scale for vision-tower (for roi_align). + use_simpleFPN_for_vt (bool): Whether to use FPN on the vision-tower output. + aux_vision_tower_region_feature_dims (List[int]): Channel dimensions of auxiliary features list. + aux_vision_tower_spatial_scale (float): Spatial scale for auxiliary vision-tower features. + """ + + def __init__( + self, + roi_output_size: int = None, # Output spatial size for ROI region features + region_feature_dim: int = 1024, # Output dimension for final region feature + apply_position_embedding: bool = False, # Whether to apply position embedding + pos_embedding_strategy: str = 'bbox_based', # Strategy: 'bbox_based', 'feature_map_based', 'hybrid' + use_vt_region_feature_only: bool = False, # Use vision tower (VT) region features only + use_vision_tower_region_feature: bool = False,# Use vision tower region features (with others) + region_feature_combination: str = 'concat', # How to combine aux and vt region features + use_separate_mlp_for_regions: bool = False, # MLP-per-region + apply_region_layer_norm: bool = False, # Apply LayerNorm + + # Primary vision tower related + vision_tower_region_feature_dim: int = 5120, # Dim of the VT region feature + vision_tower_spatial_scale: float = 1/14, # Spatial scale of the VT for roi_align + use_simpleFPN_for_vt: bool = False, # Use simpleFPN for vision tower + + # Auxiliary vision tower related + aux_vision_tower_region_feature_dims: List[int] = [256, 512, 1024, 2048], + aux_vision_tower_spatial_scale: float = None, # Scale for aux VT + ): + super(HFREModule, self).__init__() + self.roi_output_size = roi_output_size + self.region_feature_dim = region_feature_dim + self.apply_position_embedding = apply_position_embedding + self.pos_embedding_strategy = pos_embedding_strategy + self.use_vt_region_feature_only = use_vt_region_feature_only + self.use_vision_tower_region_feature = use_vision_tower_region_feature + self.region_feature_combination = region_feature_combination + self.use_separate_mlp_for_regions = use_separate_mlp_for_regions + self.apply_region_layer_norm = apply_region_layer_norm + + self.vision_tower_region_feature_dim = vision_tower_region_feature_dim + self.vision_tower_spatial_scale = vision_tower_spatial_scale + self.use_simpleFPN_for_vt = use_simpleFPN_for_vt + + self.aux_vision_tower_region_feature_dims = aux_vision_tower_region_feature_dims + self.aux_vision_tower_spatial_scale = aux_vision_tower_spatial_scale + + # Print configuration for debugging + # print(f"output_size: {self.roi_output_size} use_vision_tower_region_feature: {self.use_vision_tower_region_feature} vision_tower_region_feature_dim: {self.vision_tower_region_feature_dim} " + # f"apply_position_embedding: {self.apply_position_embedding} region_feature_combination: {self.region_feature_combination} region_feature_dim: {self.region_feature_dim} use_vt_region_feature_only: {self.use_vt_region_feature_only} " + # f"use_simpleFPN_for_vt: {self.use_simpleFPN_for_vt} pos_embedding_strategy: {self.pos_embedding_strategy} " + # f"apply_region_layer_norm: {self.apply_region_layer_norm}") + + # Optional: FPN for the vision tower input if enabled + if self.use_simpleFPN_for_vt: + self.simple_fpn = SimpleFP(out_channels=512, norm="LN", square_pad=0, dim=1280, stride=14) + + # LayerNorm for auxiliary and VT region features, if enabled + if self.apply_region_layer_norm: + if self.use_vision_tower_region_feature: + self.vt_region_norm = nn.LayerNorm(self.vision_tower_region_feature_dim) + if not self.use_vt_region_feature_only: + self.aux_region_norm = nn.LayerNorm(sum(self.aux_vision_tower_region_feature_dims)) + + # Optionally, a projection MLP if using certain combination strategies + if self.use_vision_tower_region_feature and self.region_feature_combination in ['mean', 'mean_sep_pos', 'mean_aux_pos', 'mean_sep_no_vt_pos']: + self.vision_tower_region_feature_projector = nn.Sequential( + nn.Linear(vision_tower_region_feature_dim, region_feature_dim), + nn.GELU(), + nn.Linear(region_feature_dim, region_feature_dim) + ) + + # Two MLP heads if regions are projected separately (for concat mode) + if self.use_vision_tower_region_feature and self.use_separate_mlp_for_regions: + self.vt_region_mlp = nn.Sequential( + nn.Linear(2048, 1024), + nn.GELU(), + nn.Linear(1024, 1024) + ) + self.aux_region_mlp = nn.Sequential( + nn.Linear(2048, 1024), + nn.GELU(), + nn.Linear(1024, 1024) + ) + + def _apply_feature_map_position_embedding(self, features): + """Apply 2D position embedding to each feature map in a feature pyramid, if enabled. + + Args: + features (List[Tensor]): Each is [B, C, H, W] per FPN level. + + Returns: + List[Tensor]: Feature maps with position embedding applied, shape unchanged. + """ + enhanced_features = [] + for level_idx, feature in enumerate(features): + if self.apply_position_embedding and self.pos_embedding_strategy in ['feature_map_based', 'hybrid']: + B, C, H, W = feature.shape + + # Generate position embedding matching channel dimension + pos_embed = generate_2d_position_embedding( + H, W, C, feature.device + ) # [H, W, C] + + # Reshape to [1, C, H, W] and add + pos_embed = pos_embed.permute(2, 0, 1).unsqueeze(0) + feature = feature + pos_embed.to(feature.dtype) + enhanced_features.append(feature) + return enhanced_features + + def extract_vt_region_feature(self, multi_level_features, boxes: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor: + """Extract vision-tower region features via ROIAlign over FPN features, with spatial scaling. + + Args: + multi_level_features (List[Tensor]): Per-FPN level features, [B, C, H, W]. + boxes (Union[Tensor, List[Tensor]]): ROI bounding boxes for roi_align. + + Returns: + Tensor: [1, N, C=tower_channels] + """ + if self.use_simpleFPN_for_vt: + # If using FPN for vision tower: apply FPN and select fixed spatial scales (hardcoded stride) + multi_level_features = self.simple_fpn(multi_level_features) + roi_features_per_level = [] + # Hardcoded feature strides for each FPN stage; tweak if arch changes + feature_strides = [3.5, 7, 14, 28] + for level_idx, level_feature in enumerate(multi_level_features): + current_spatial_scale = 1.0 / feature_strides[level_idx] + level_roi_feat = roi_align( + level_feature.float(), + boxes, + output_size=self.roi_output_size, + spatial_scale=current_spatial_scale + ) + # Pool across H,W to get region feature per ROI + level_roi_feat = level_roi_feat.mean(dim=(2, 3)) + roi_features_per_level.append(level_roi_feat) + out_box_feat = torch.cat(roi_features_per_level, dim=1).unsqueeze(0) + else: + # If not using FPN: concatenate all feature levels on channel axis and ROI-align once + concat_multi_level_feature = [] + concat_multi_level_feature = torch.cat(multi_level_features, dim=1) + + out_box_feat = roi_align( + concat_multi_level_feature.float(), + boxes, + output_size=self.roi_output_size, + spatial_scale=self.vision_tower_spatial_scale, + ) + # Pool per ROI for (1, N, C_total) + out_box_feat = out_box_feat.mean(dim=(2, 3)).reshape( + 1, out_box_feat.shape[0], out_box_feat.shape[1] + ) + return out_box_feat + + def __call__( + self, + aux_multi_level_features: List[torch.Tensor], + aux_boxes: Union[torch.Tensor, List[torch.Tensor]], + vt_multi_level_features = None, + vt_boxes: Union[torch.Tensor, List[torch.Tensor]] = None, + ) -> torch.Tensor: + """Main forward. Extracts ROI region features with possible hybrid VT/aux, applies position embedding and combines as configured. + + Args: + aux_multi_level_features (List[Tensor]): Auxiliary vision features (e.g., from FPN, [B, C, H, W]). + aux_boxes (Union[Tensor, List[Tensor]]): ROIs in [N, 4] xyxy. + vt_multi_level_features (optional): Vision tower features. + vt_boxes (optional): Vision tower's box coordinates ([N, 4]). + + Returns: + Tensor: Region features of shape [1, N, C], N=#ROIs. + """ + if self.use_vt_region_feature_only: + # Only use VT region features (skip aux completely) + out_box_feat = self.extract_vt_region_feature(vt_multi_level_features, vt_boxes) + + if self.apply_position_embedding: + # Add position embedding to VT region feature + pos_boxes = vt_boxes[0] # (N, 4) + pos_boxes = pos_boxes.to(out_box_feat.dtype) + vt_max_height = max([feature.shape[-2] for feature in vt_multi_level_features]) + vt_max_width = max([feature.shape[-1] for feature in vt_multi_level_features]) + original_img_width = vt_max_width / self.vision_tower_spatial_scale + original_img_height = vt_max_height / self.vision_tower_spatial_scale + # Normalize box coordinates by image size + pos_boxes[:, [0, 2]] = pos_boxes[:, [0, 2]] / original_img_width + pos_boxes[:, [1, 3]] = pos_boxes[:, [1, 3]] / original_img_height + # Convert from (x1, y1, x2, y2) to (cx, cy, w, h) + pos_boxes[:, 2] = pos_boxes[:, 2] - pos_boxes[:, 0] + pos_boxes[:, 3] = pos_boxes[:, 3] - pos_boxes[:, 1] + pos_boxes[:, 0] = pos_boxes[:, 0] + pos_boxes[:, 2] / 2 + pos_boxes[:, 1] = pos_boxes[:, 1] + pos_boxes[:, 3] / 2 + # Add sine/cos position embedding + pos_embed = gen_sineembed_for_position(pos_boxes.unsqueeze(0), self.region_feature_dim // 4) + out_box_feat = out_box_feat + pos_embed + return out_box_feat + + # Otherwise: hybrid mode (aux + possibly VT region features) + aux_boxes[0] = aux_boxes[0].float() + + # Collect all auxiliary features at the same (max) spatial size for channel concat + concat_multi_level_feature = [] + max_height = max([feature.shape[2] for feature in aux_multi_level_features]) + max_width = max([feature.shape[3] for feature in aux_multi_level_features]) + + # Optionally apply 2D position encoding at the feature map level (before concat/roi_align) + if self.pos_embedding_strategy in ['feature_map_based', 'hybrid']: + # Option: compute stride info for each level for debugging/extension + feature_strides = [] + for feature in aux_multi_level_features: + stride = max_height / feature.shape[2] + feature_strides.append(stride) + aux_multi_level_features = self._apply_feature_map_position_embedding( + aux_multi_level_features + ) + + # Interpolate all features to (max_height,max_width), then concat along channel + for level, feature in enumerate(aux_multi_level_features): + if level != 0: + concat_multi_level_feature.append( + F.interpolate( + feature.float(), + size=(max_height, max_width), + mode="bilinear", + align_corners=False, + ) + ) + else: + concat_multi_level_feature.append(feature.float()) + concat_multi_level_feature = torch.cat(concat_multi_level_feature, dim=1) + + # Extract region feature for all boxes using roi_align + out_box_aux_feat = roi_align( + concat_multi_level_feature, + aux_boxes, + output_size=self.roi_output_size, + spatial_scale=self.aux_vision_tower_spatial_scale + ) + + # Pool H,W to get final shape (1, Nbox, C) + out_box_aux_feat = out_box_aux_feat.mean(dim=(2, 3)).reshape( + 1, out_box_aux_feat.shape[0], out_box_aux_feat.shape[1] + ) + + if self.apply_region_layer_norm: + out_box_aux_feat = self.aux_region_norm.float()(out_box_aux_feat) + + if self.use_vision_tower_region_feature: + # If also using vision-tower features + out_box_vt_feat = self.extract_vt_region_feature(vt_multi_level_features, vt_boxes) + if self.apply_region_layer_norm: + out_box_vt_feat = self.vt_region_norm.float()(out_box_vt_feat) + if self.region_feature_combination in ['mean', 'mean_aux_pos']: + # Combine by mean + out_box_feat = (out_box_aux_feat + out_box_vt_feat) / 2 + elif self.region_feature_combination in ['concat', 'concat_aux_pos']: + # Optionally MLP each before concat + if self.use_separate_mlp_for_regions: + original_vt_dtype = out_box_vt_feat.dtype + original_aux_dtype = out_box_aux_feat.dtype + out_box_vt_feat = self.vt_region_mlp(out_box_vt_feat.to(self.vt_region_mlp[0].weight.dtype)).to(original_vt_dtype) + out_box_aux_feat = self.aux_region_mlp(out_box_aux_feat.to(self.aux_region_mlp[0].weight.dtype)).to(original_aux_dtype) + out_box_feat = torch.cat([out_box_aux_feat, out_box_vt_feat], dim=-1) + elif self.region_feature_combination in ['concat_sep_pos', 'mean_sep_pos', 'concat_sep_no_vt_pos', 'mean_sep_no_vt_pos']: + # Compute position embedding separately for aux and vt features + # Use `aux_boxes` for aux and `vt_boxes` for vt + vt_dim = 5120 if self.region_feature_combination == 'concat_sep_pos' else 2880 + + # Aux region: positional embedding using aux_boxes + aux_pos_boxes = aux_boxes[0].to(out_box_aux_feat.dtype) # (N, 4) + aux_original_img_width = max_width / self.aux_vision_tower_spatial_scale + aux_original_img_height = max_height / self.aux_vision_tower_spatial_scale + + aux_pos_boxes[:, [0, 2]] = aux_pos_boxes[:, [0, 2]] / aux_original_img_width + aux_pos_boxes[:, [1, 3]] = aux_pos_boxes[:, [1, 3]] / aux_original_img_height + aux_pos_boxes[:, 2] = aux_pos_boxes[:, 2] - aux_pos_boxes[:, 0] + aux_pos_boxes[:, 3] = aux_pos_boxes[:, 3] - aux_pos_boxes[:, 1] + aux_pos_boxes[:, 0] = aux_pos_boxes[:, 0] + aux_pos_boxes[:, 2] / 2 + aux_pos_boxes[:, 1] = aux_pos_boxes[:, 1] + aux_pos_boxes[:, 3] / 2 + aux_pos_embed = gen_sineembed_for_position( + aux_pos_boxes.unsqueeze(0), 2880 // 4 + ) + out_box_aux_feat = out_box_aux_feat + aux_pos_embed + + # Only apply VT position embedding in these combos: + # For *_no_vt_pos: skip vt feature position embedding + if self.region_feature_combination in ['concat_sep_no_vt_pos', 'mean_sep_no_vt_pos']: + pass + else: + # VT region: positional embedding using vt_boxes + vt_pos_boxes = vt_boxes[0].to(out_box_vt_feat.dtype) # (N, 4) + vt_max_height = max([feature.shape[2] for feature in vt_multi_level_features]) + vt_max_width = max([feature.shape[3] for feature in vt_multi_level_features]) + vt_original_img_width = vt_max_width / self.vision_tower_spatial_scale + vt_original_img_height = vt_max_height / self.vision_tower_spatial_scale + + vt_pos_boxes[:, [0, 2]] = vt_pos_boxes[:, [0, 2]] / vt_original_img_width + vt_pos_boxes[:, [1, 3]] = vt_pos_boxes[:, [1, 3]] / vt_original_img_height + vt_pos_boxes[:, 2] = vt_pos_boxes[:, 2] - vt_pos_boxes[:, 0] + vt_pos_boxes[:, 3] = vt_pos_boxes[:, 3] - vt_pos_boxes[:, 1] + vt_pos_boxes[:, 0] = vt_pos_boxes[:, 0] + vt_pos_boxes[:, 2] / 2 + vt_pos_boxes[:, 1] = vt_pos_boxes[:, 1] + vt_pos_boxes[:, 3] / 2 + vt_pos_embed = gen_sineembed_for_position( + vt_pos_boxes.unsqueeze(0), vt_dim // 4 + ) + out_box_vt_feat = out_box_vt_feat + vt_pos_embed + + # Merge aux and vt region features (by cat or mean) + if self.region_feature_combination in ['concat_sep_pos', 'concat_sep_no_vt_pos']: + out_box_feat = torch.cat([out_box_aux_feat, out_box_vt_feat], dim=-1) + elif self.region_feature_combination in ['mean_sep_pos', 'mean_sep_no_vt_pos']: + out_box_feat = (out_box_aux_feat + out_box_vt_feat) / 2 + + + # If enabled: add single positional embedding (bbox-based, not separate for each region type) + if self.apply_position_embedding and self.region_feature_combination not in ['concat_sep_pos', 'mean_sep_pos', 'concat_sep_no_vt_pos', 'mean_sep_no_vt_pos']: + # Only apply if position embedding strategy matches + apply_bbox_pos_embed = (self.pos_embedding_strategy == 'bbox_based' or self.pos_embedding_strategy == 'hybrid') + + if apply_bbox_pos_embed: + # Use vt_boxes unless not enabled or configured otherwise + if self.use_vision_tower_region_feature and vt_boxes is not None and self.region_feature_combination not in ['concat_aux_pos', 'mean_aux_pos']: + pos_boxes = vt_boxes[0] # (N, 4) + vt_max_height = max([feature.shape[-2] for feature in vt_multi_level_features]) + vt_max_width = max([feature.shape[-1] for feature in vt_multi_level_features]) + vt_spatial_scale = self.vision_tower_spatial_scale + original_img_width = vt_max_width / vt_spatial_scale + original_img_height = vt_max_height / vt_spatial_scale + else: + max_width = max([feature.shape[3] for feature in aux_multi_level_features]) + max_height = max([feature.shape[2] for feature in aux_multi_level_features]) + pos_boxes = aux_boxes[0] # (N, 4) + original_img_width = max_width / self.aux_vision_tower_spatial_scale + original_img_height = max_height / self.aux_vision_tower_spatial_scale + + pos_boxes = pos_boxes.to(out_box_feat.dtype) + pos_boxes[:, [0, 2]] = pos_boxes[:, [0, 2]] / original_img_width + pos_boxes[:, [1, 3]] = pos_boxes[:, [1, 3]] / original_img_height + # Convert box to center format + pos_boxes[:, 2] = pos_boxes[:, 2] - pos_boxes[:, 0] + pos_boxes[:, 3] = pos_boxes[:, 3] - pos_boxes[:, 1] + pos_boxes[:, 0] = pos_boxes[:, 0] + pos_boxes[:, 2] / 2 + pos_boxes[:, 1] = pos_boxes[:, 1] + pos_boxes[:, 3] / 2 + pos_embed = gen_sineembed_for_position( + pos_boxes.unsqueeze(0), self.region_feature_dim // 4 + ) + out_box_feat = out_box_feat + pos_embed + + return out_box_feat \ No newline at end of file diff --git a/vlm_fo1/model/multimodal_visual_prompt_encoder/simple_fpn.py b/vlm_fo1/model/multimodal_visual_prompt_encoder/simple_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..d44b781fb72574fab4be51d6efab6d5b1cecfd69 --- /dev/null +++ b/vlm_fo1/model/multimodal_visual_prompt_encoder/simple_fpn.py @@ -0,0 +1,257 @@ +import math + +import torch +import torch.nn as nn +from torch.nn import functional as F +import warnings + + +class Conv2d(torch.nn.Conv2d): + """ + A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features. + """ + + def __init__(self, *args, **kwargs): + """ + Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`: + + Args: + norm (nn.Module, optional): a normalization layer + activation (callable(Tensor) -> Tensor): a callable activation function + + It assumes that norm layer is used before activation. + """ + norm = kwargs.pop("norm", None) + activation = kwargs.pop("activation", None) + super().__init__(*args, **kwargs) + + self.norm = norm + self.activation = activation + + def forward(self, x): + # torchscript does not support SyncBatchNorm yet + # https://github.com/pytorch/pytorch/issues/40507 + # and we skip these codes in torchscript since: + # 1. currently we only support torchscript in evaluation mode + # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or + # later version, `Conv2d` in these PyTorch versions has already supported empty inputs. + if not torch.jit.is_scripting(): + # Dynamo doesn't support context managers yet + is_dynamo_compiling = True + if not is_dynamo_compiling: + with warnings.catch_warnings(record=True): + if x.numel() == 0 and self.training: + # https://github.com/pytorch/pytorch/issues/12013 + assert not isinstance( + self.norm, torch.nn.SyncBatchNorm + ), "SyncBatchNorm does not support empty inputs!" + + x = F.conv2d( + x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups + ) + if self.norm is not None: + x = self.norm(x) + if self.activation is not None: + x = self.activation(x) + return x + +class LayerNorm(nn.Module): + """ + A LayerNorm variant, popularized by Transformers, that performs point-wise mean and + variance normalization over the channel dimension for inputs that have shape + (batch_size, channels, height, width). + https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950 + """ + + def __init__(self, normalized_shape, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.normalized_shape = (normalized_shape,) + + def forward(self, x): + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + +def get_norm(norm, out_channels): + """ + Args: + norm (str or callable): either one of BN, SyncBN, FrozenBN, GN; + or a callable that takes a channel number and returns + the normalization layer as a nn.Module. + + Returns: + nn.Module or None: the normalization layer + """ + if norm is None: + return None + if isinstance(norm, str): + if len(norm) == 0: + return None + norm = { + "LN": lambda channels: LayerNorm(channels), + }[norm] + return norm(out_channels) + +class SimpleFP(nn.Module): + """ + This module implements SimpleFPN in :paper:`vitdet`. + It creates pyramid features built on top of the input feature map. + """ + + def __init__( + self, + out_channels, + scale_factors=[4.0, 2.0, 1.0, 0.5], + top_block=None, + norm="LN", + square_pad=0, + dim=1024, + stride=14, + ): + """ + Args: + out_channels (int): number of channels in the output feature maps. + scale_factors (list[float]): list of scaling factors to upsample or downsample + the input features for creating pyramid features. + top_block (nn.Module or None): if provided, an extra operation will + be performed on the output of the last (smallest resolution) + pyramid output, and the result will extend the result list. The top_block + further downsamples the feature map. It must have an attribute + "num_levels", meaning the number of extra pyramid levels added by + this block, and "in_feature", which is a string representing + its input feature (e.g., p5). + norm (str): the normalization to use. + square_pad (int): If > 0, require input images to be padded to specific square size. + """ + super(SimpleFP, self).__init__() + + self.scale_factors = scale_factors + + strides = [int(stride / scale) for scale in scale_factors] + + self.stages = [] + use_bias = norm == "" + for idx, scale in enumerate(scale_factors): + out_dim = dim + if scale == 4.0: + layers = [ + nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2), + get_norm(norm, dim // 2), + nn.GELU(), + nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2), + ] + out_dim = dim // 4 + elif scale == 2.0: + layers = [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)] + out_dim = dim // 2 + elif scale == 1.0: + layers = [] + elif scale == 0.5: + layers = [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + raise NotImplementedError(f"scale_factor={scale} is not supported yet.") + + layers.extend( + [ + Conv2d( + out_dim, + out_channels, + kernel_size=1, + bias=use_bias, + norm=get_norm(norm, out_channels), + ), + Conv2d( + out_channels, + out_channels, + kernel_size=3, + padding=1, + bias=use_bias, + norm=get_norm(norm, out_channels), + ), + ] + ) + layers = nn.Sequential(*layers) + + stage = int(math.log2(strides[idx])) + self.add_module(f"simfp_{stage}", layers) + self.stages.append(layers) + + self.top_block = top_block + # Return feature names are "p", like ["p2", "p3", ..., "p6"] + self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides} + # top block output feature maps. + if self.top_block is not None: + for s in range(stage, stage + self.top_block.num_levels): + self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1) + + self._out_features = list(self._out_feature_strides.keys()) + self._out_feature_channels = {k: out_channels for k in self._out_features} + self._size_divisibility = strides[-1] + self._square_pad = square_pad + + def forward(self, x): + """ + Args: + x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. + + Returns: + dict[str->Tensor]: + mapping from feature map name to pyramid feature map tensor + in high to low resolution order. Returned feature names follow the FPN + convention: "p", where stage has stride = 2 ** stage e.g., + ["p2", "p3", ..., "p6"]. + """ + features = x + results = [] + + for stage in self.stages: + results.append(stage(features)) + + assert len(self._out_features) == len(results) + return results + + +if __name__ == "__main__": + """ + Test the functionality of SimpleFPN (Feature Pyramid Network). + + The test uses an input tensor of shape (1, 1024, 28, 28). + """ + import torch + + # Generate a dummy input tensor of shape (batch_size=1, channels=1024, height=28, width=28) + test_input = torch.randn(1, 1024, 28, 28) + + # Instantiate the SimpleFP (assumed to be the Feature Pyramid module) + # Note: The arguments below should be checked and adapted according to SimpleFP's actual constructor. + fpn = SimpleFP( + out_channels=256, # Number of output channels for FPN layers + norm="LN", # Normalization type, here using LayerNorm ("LN") + square_pad=0, # Square padding size if needed (here 0 means no padding) + dim=1024, # Number of input channels/features from the backbone + stride=14 # Stride setting, typically related to feature scaling + ) + + # ~~~~~ Model Forward Pass ~~~~~ + # Compute FPN outputs with torch.no_grad() to avoid tracking gradients (for eval/testing) + with torch.no_grad(): + output = fpn(test_input) # Expected: result is a list of feature tensors at different FPN stages + + # ~~~~~ Print Input/Output Information ~~~~~ + print("SimpleFPN Test Results:") + print(f"Input Shape: {test_input.shape}") + print("Output feature maps from each FPN stage:") + + # NOTE: If the output is a list, the features are not named (unlike dict). Adapt print info accordingly. + for idx, feature_map in enumerate(output): + print(f" Output stage {idx}: shape = {feature_map.shape}") + + # If output were a dict with feature names, iterate as below instead: + # for feature_name, feature_map in output.items(): + # print(f" {feature_name}: {feature_map.shape}") + \ No newline at end of file diff --git a/vlm_fo1/model/omchat_arch.py b/vlm_fo1/model/omchat_arch.py new file mode 100755 index 0000000000000000000000000000000000000000..a00bb4a8915a4fa953b7cfed9c4a5e9263e6817b --- /dev/null +++ b/vlm_fo1/model/omchat_arch.py @@ -0,0 +1,72 @@ +from abc import ABC, abstractmethod + +from vlm_fo1.model.multimodal_encoder.builder import build_vision_tower, build_vision_tower_aux +from vlm_fo1.model.multimodal_projector.builder import build_vision_projector, build_vision_projector_aux +from vlm_fo1.model.multimodal_visual_prompt_encoder.hybrid_finegrained_region_encoder import HFREModule + +class OmChatMetaModel: + def __init__(self, config): + super(OmChatMetaModel, self).__init__(config) + # print('----------------------delay_load:', config.delay_load) + if getattr(config, "mm_vision_tower", None) is not None: + self.vision_tower = build_vision_tower(config, delay_load=getattr(config, 'delay_load', True)) + if getattr(config, "mm_vision_tower", None) is not None: + self.mm_projector = build_vision_projector(config) + if getattr(config, "mm_vision_tower_aux", None) is not None: + self.vision_tower_aux = build_vision_tower_aux(config, delay_load=getattr(config, 'delay_load', True)) + self.object_vp_extractor = HFREModule( + roi_output_size=getattr(config, "mm_roi_output_size", 7), + region_feature_dim=config.mm_region_hidden_size, + apply_position_embedding=getattr(config, "mm_apply_position_embedding", True), + pos_embedding_strategy=getattr(config, "mm_pos_embedding_strategy", "bbox_based"), + use_vt_region_feature_only=getattr(config, "mm_use_vt_region_feature_only", False), + use_vision_tower_region_feature=getattr(config, "mm_use_vision_tower_region_feature", False), + region_feature_combination=getattr(config, "mm_region_feature_combination", "concat"), + apply_region_layer_norm=getattr(config, "mm_apply_region_layer_norm", False), + vision_tower_region_feature_dim=self.get_vision_tower().config.hidden_size * 4 if not getattr(config, "mm_use_simpleFPN_for_vt", False) else 2048, + vision_tower_spatial_scale=1/self.get_vision_tower().config.patch_size, + use_simpleFPN_for_vt=getattr(config, "mm_use_simpleFPN_for_vt", False), + aux_vision_tower_spatial_scale=0.25, + aux_vision_tower_region_feature_dims=[256, 512, 1024, 2048], + ) + if getattr(config, "mm_vision_tower_aux", None) is not None: + self.mm_projector_aux = build_vision_projector_aux(config) + + def get_vision_tower(self): + vision_tower = getattr(self, 'vision_tower', None) + if type(vision_tower) is list: + vision_tower = vision_tower[0] + return vision_tower + + def get_vision_tower_aux(self): + vision_tower_aux = getattr(self, 'vision_tower_aux', None) + if type(vision_tower_aux) is list: + vision_tower_aux = vision_tower_aux[0] + return vision_tower_aux + + def get_video_tower(self): + video_tower = getattr(self, 'video_tower', None) + if type(video_tower) is list: + video_tower = video_tower[0] + return video_tower + + +class OmChatMetaForCausalLM(ABC): + + @abstractmethod + def get_model(self): + pass + + def get_vision_tower(self): + return self.get_model().get_vision_tower() + + def get_vision_tower_aux(self): + return self.get_model().get_vision_tower_aux() + + def get_video_tower(self): + return self.get_model().get_vision_tower() + + def encode_videos(self, videos): # [mini_b, c, t, h, w] + video_features = self.get_model().get_video_tower()(videos) # [mini_b, t, n, c] + video_features = self.get_model().mm_projector.forward_video(video_features) + return video_features \ No newline at end of file diff --git a/vlm_fo1/task_templates.py b/vlm_fo1/task_templates.py new file mode 100644 index 0000000000000000000000000000000000000000..5c8194279c5cb8060ca4474a14a03f0784305932 --- /dev/null +++ b/vlm_fo1/task_templates.py @@ -0,0 +1,17 @@ +OD_template = "Please detect {} in this image. Answer the question with object indexes." + +OD_Counting_template = "How many {} are there in this image? Count each instance of the target object. Locate them with object indexes and then answer the question with the number of objects." + +REC_template = "Please detect {} in this image. Answer the question with object indexes." + +Region_OCR_template = "Please provide the ocr results of {} in the image." + +Brief_Region_Caption_template = "Provide a brief description for {}." + +Detailed_Region_Caption_template = "Provide a detailed description for {}." + +Grounding_template = "Briefly describe this image and detect all mentioned objects. Answer with grounded object indexes." + +Visual_Prompt_OD_template = "Using the provided object {} as a reference, identify all other objects of the same category in this image. Respond with object indexes." + +Viusal_Region_Reasoning_template = "First thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Please give a detailed reasoning process process and provide image regions that can help you answer the question better. {}"