# open_groundingdino / modeling_groundingdino.py
# (Hugging Face Hub page artifacts preserved as comments — these lines were
#  not valid Python: "m522t's picture / Upload open_groundingdino model /
#  fb7bd9e verified")
"""
Custom GroundingDINO model class for transformers compatibility.
"""
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
class GroundingDINOConfig(PretrainedConfig):
    """Configuration for a GroundingDINO model.

    Holds the architecture hyper-parameters (query count, transformer
    depth/width, text-encoder choice, backbone, ...) consumed when the
    model is built.
    """

    model_type = "groundingdino"

    def __init__(
        self,
        num_classes=1180,
        num_queries=900,
        hidden_dim=256,
        num_feature_levels=4,
        nheads=8,
        enc_layers=6,
        dec_layers=6,
        dim_feedforward=2048,
        dropout=0.0,
        max_text_len=256,
        text_encoder_type="bert-base-uncased",
        backbone="swin_T_224_1k",
        position_embedding="sine",
        **kwargs
    ):
        # Let PretrainedConfig consume the generic kwargs first
        # (e.g. id2label, name_or_path, ...).
        super().__init__(**kwargs)
        # Record every GroundingDINO-specific hyper-parameter on the
        # instance so PretrainedConfig serialization picks them up.
        for attr_name, attr_value in (
            ("num_classes", num_classes),
            ("num_queries", num_queries),
            ("hidden_dim", hidden_dim),
            ("num_feature_levels", num_feature_levels),
            ("nheads", nheads),
            ("enc_layers", enc_layers),
            ("dec_layers", dec_layers),
            ("dim_feedforward", dim_feedforward),
            ("dropout", dropout),
            ("max_text_len", max_text_len),
            ("text_encoder_type", text_encoder_type),
            ("backbone", backbone),
            ("position_embedding", position_embedding),
        ):
            setattr(self, attr_name, attr_value)
class GroundingDINOModel(PreTrainedModel):
    """Minimal ``transformers`` wrapper around a GroundingDINO network.

    This class is a placeholder: ``self.model`` is left as ``None`` here, so
    :meth:`forward` raises ``NotImplementedError`` until a real GroundingDINO
    module is attached (e.g. ``wrapper.model = built_groundingdino``).
    """

    config_class = GroundingDINOConfig

    def __init__(self, config):
        super().__init__(config)
        # NOTE: PreTrainedModel.__init__ already stores ``config`` on
        # ``self.config``; the original redundant re-assignment was removed.
        # Placeholder for the actual network — callers must attach a module
        # before calling forward().
        self.model = None

    def forward(self, images, text_prompts=None, return_dict=True):
        """Run the wrapped GroundingDINO network.

        Args:
            images: Batched input image tensor.
            text_prompts: Text prompts for grounding; forwarded to the
                underlying model as ``captions``.
            return_dict: If ``True``, return a dict with ``logits``,
                ``boxes`` and ``last_hidden_state``; otherwise return the
                raw model output.

        Returns:
            A dict of tensors when ``return_dict`` is ``True``, else the raw
            output of the wrapped model.

        Raises:
            NotImplementedError: If no underlying model has been attached.
        """
        if self.model is None:
            raise NotImplementedError(
                "Model architecture not implemented. "
                "Please use the original GroundingDINO implementation for inference."
            )
        outputs = self.model(images, captions=text_prompts)
        if not return_dict:
            return outputs
        # Missing keys fall back to a shared empty tensor so downstream code
        # can rely on the dict having all three entries.
        empty = torch.tensor([])
        return {
            "logits": outputs.get("pred_logits", empty),
            "boxes": outputs.get("pred_boxes", empty),
            "last_hidden_state": outputs.get("last_hidden_state", empty),
        }