momergul committed on
Commit
67f9440
·
verified ·
1 Parent(s): 4be5292

Upload modeling_git.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_git.py +55 -0
modeling_git.py CHANGED
@@ -99,3 +99,58 @@ class GitForCausalLM(modeling_git.GitForCausalLM):
99
  hidden_states=outputs.hidden_states,
100
  attentions=outputs.attentions,
101
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  hidden_states=outputs.hidden_states,
100
  attentions=outputs.attentions,
101
  )
102
+
103
class GitModel(modeling_git.GitForCausalLM):
    """GIT model whose image encoder is swapped for a pretrained DINO ViT.

    The class subclasses ``GitForCausalLM`` but deletes the ``output`` head
    and returns the outputs of ``self.git`` directly from ``forward``; no
    logits or loss are computed here.
    """

    def __init__(self, *args, **kwargs):
        """Build the base model, then replace its vision components.

        All positional and keyword arguments are forwarded unchanged to the
        parent constructor.
        """
        super().__init__(*args, **kwargs)

        # The output head is not used by forward(), so drop it.
        del self.output
        # NOTE(review): post_init() runs before the replacement modules below
        # are created, so they are not covered by this call -- confirm that
        # this ordering is intended.
        self.post_init()

        # Replace the stock image encoder with the DINO ViT-B/16 checkpoint.
        del self.git.image_encoder
        self.git.image_encoder = ViTModel.from_pretrained('facebook/dino-vitb16')
        dino_cfg = self.git.image_encoder.config
        config = self.git.config
        config.vision_config.hidden_size = dino_cfg.hidden_size

        # Rebuild the visual projection from the updated config.
        del self.git.visual_projection
        self.git.visual_projection = modeling_git.GitProjection(config)

        # One token per image patch plus one extra token
        # (presumably the ViT [CLS] token -- TODO confirm).
        num_tks = (dino_cfg.image_size // dino_cfg.patch_size) ** 2 + 1
        # Apply the count to every encoder layer, not only the first one, so
        # all self-attention modules agree on where the image tokens end.
        for layer in self.git.encoder.layer:
            layer.attention.self.image_patch_tokens = num_tks

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], modeling_git.CausalLMOutputWithPast]:
        """Run the wrapped ``self.git`` model and return its outputs as-is.

        ``labels`` is only used to switch ``use_cache`` off; no loss is
        computed from it. Extra ``**kwargs`` are accepted and ignored.
        ``return_dict`` falls back to ``self.config.use_return_dict`` when it
        is ``None``.

        NOTE(review): the return annotation names ``CausalLMOutputWithPast``,
        but the value returned is whatever ``self.git`` produces -- confirm
        the annotation against the base model's output type.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        outputs = self.git(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            pixel_values=pixel_values,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return outputs
156
+