| """ |
| Custom GroundingDINO model class for transformers compatibility. |
| """ |


import torch
from transformers import PreTrainedModel, PretrainedConfig


class GroundingDINOConfig(PretrainedConfig):
    """Configuration class for GroundingDINO."""

    model_type = "groundingdino"

    def __init__(
        self,
        num_classes=1180,
        num_queries=900,
        hidden_dim=256,
        num_feature_levels=4,
        nheads=8,
        enc_layers=6,
        dec_layers=6,
        dim_feedforward=2048,
        dropout=0.0,
        max_text_len=256,
        text_encoder_type="bert-base-uncased",
        backbone="swin_T_224_1k",
        position_embedding="sine",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.num_queries = num_queries
        self.hidden_dim = hidden_dim
        self.num_feature_levels = num_feature_levels
        self.nheads = nheads
        self.enc_layers = enc_layers
        self.dec_layers = dec_layers
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.max_text_len = max_text_len
        self.text_encoder_type = text_encoder_type
        self.backbone = backbone
        self.position_embedding = position_embedding
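

# Usage sketch: building and serializing the config with the standard
# PretrainedConfig API (save_pretrained / from_pretrained). The directory name
# below is an illustrative placeholder, not a path used by this project.
#
#     config = GroundingDINOConfig(num_queries=900, backbone="swin_T_224_1k")
#     config.save_pretrained("./groundingdino-checkpoint")  # writes config.json
#     config = GroundingDINOConfig.from_pretrained("./groundingdino-checkpoint")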


class GroundingDINOModel(PreTrainedModel):
    """GroundingDINO model wrapper for transformers."""

    config_class = GroundingDINOConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # The underlying GroundingDINO network is not built here; it is expected
        # to be attached externally (see the check in forward()).
        self.model = None

    def forward(self, images, text_prompts=None, return_dict=True):
        """
        Forward pass of the model.

        Args:
            images: Input image tensor.
            text_prompts: Text prompts (captions) used for grounding.
            return_dict: Whether to return the outputs as a dictionary.

        Returns:
            A dict with "logits", "boxes", and "last_hidden_state" when
            return_dict is True, otherwise the raw model outputs.
        """
        if self.model is None:
            raise NotImplementedError(
                "Model architecture not implemented. "
                "Please use the original GroundingDINO implementation for inference."
            )

        outputs = self.model(images, captions=text_prompts)

        if return_dict:
            return {
                "logits": outputs.get("pred_logits", torch.tensor([])),
                "boxes": outputs.get("pred_boxes", torch.tensor([])),
                "last_hidden_state": outputs.get("last_hidden_state", torch.tensor([])),
            }
        else:
            return outputs
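

# Usage sketch: registering the wrapper with the Auto classes and attaching an
# externally built network. AutoConfig.register and AutoModel.register are
# standard transformers APIs; `build_groundingdino` is a hypothetical stand-in
# for whatever builder the original GroundingDINO implementation provides.
#
#     from transformers import AutoConfig, AutoModel
#
#     AutoConfig.register("groundingdino", GroundingDINOConfig)
#     AutoModel.register(GroundingDINOConfig, GroundingDINOModel)
#
#     config = GroundingDINOConfig()
#     model = GroundingDINOModel(config)
#     model.model = build_groundingdino(config)  # hypothetical builder from the original repo
#     outputs = model(images, text_prompts=["a cat ."])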