momergul committed on
Commit
ca86ae8
·
verified ·
1 Parent(s): 3c31112

Upload modeling_git.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_git.py +156 -0
modeling_git.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+ from transformers import AutoProcessor, AutoModelForCausalLM
3
+ from transformers import ViTFeatureExtractor, ViTModel, ViTConfig
4
+ from typing import List, Optional, Tuple, Union
5
+ import warnings
6
+ import ipdb
7
+ import os
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn import CrossEntropyLoss
11
+ from itertools import product
12
+ import numpy as np
13
+ import transformers.models.git.modeling_git as modeling_git
14
+ import transformers.models.vit.modeling_vit as modeling_vit
15
+ from transformers.models.opt.modeling_opt import OPTConfig
16
+ import transformers.models.opt.modeling_opt as hg_opt
17
+ import transformers.models.clip.modeling_clip as modeling_clip
18
+
19
+
20
+ class GitForCausalLM(modeling_git.GitForCausalLM):
21
+ def __init__(self, *args, **kwargs):
22
+ super().__init__(*args, **kwargs)
23
+
24
+ del self.output
25
+ self.output = nn.Linear(
26
+ self.config.hidden_size,
27
+ self.config.vocab_size,
28
+ bias=False)
29
+ self.post_init()
30
+
31
+ del self.git.image_encoder
32
+ self.git.image_encoder = ViTModel.from_pretrained('facebook/dino-vitb16')
33
+ dino_cfg = self.git.image_encoder.config
34
+ config = self.git.config
35
+ config.vision_config.hidden_size = dino_cfg.hidden_size
36
+
37
+ del self.git.visual_projection
38
+ self.git.visual_projection = modeling_git.GitProjection(config)
39
+ num_tks = (dino_cfg.image_size // dino_cfg.patch_size) ** 2 + 1
40
+ self.git.encoder.layer[0].attention.self.image_patch_tokens = num_tks
41
+
42
+ def forward(
43
+ self,
44
+ input_ids: Optional[torch.Tensor] = None,
45
+ attention_mask: Optional[torch.Tensor] = None,
46
+ position_ids: Optional[torch.Tensor] = None,
47
+ pixel_values: Optional[torch.Tensor] = None,
48
+ head_mask: Optional[torch.Tensor] = None,
49
+ inputs_embeds: Optional[torch.Tensor] = None,
50
+ labels: Optional[torch.Tensor] = None,
51
+ past_key_values: Optional[List[torch.Tensor]] = None,
52
+ use_cache: Optional[bool] = None,
53
+ output_attentions: Optional[bool] = None,
54
+ output_hidden_states: Optional[bool] = None,
55
+ return_dict: Optional[bool] = None,
56
+ **kwargs,
57
+ ) -> Union[Tuple[torch.Tensor], modeling_git.CausalLMOutputWithPast]:
58
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
59
+ if labels is not None:
60
+ use_cache = False
61
+
62
+ outputs = self.git(
63
+ input_ids,
64
+ attention_mask=attention_mask,
65
+ position_ids=position_ids,
66
+ pixel_values=pixel_values,
67
+ head_mask=head_mask,
68
+ inputs_embeds=inputs_embeds,
69
+ past_key_values=past_key_values,
70
+ use_cache=use_cache,
71
+ output_attentions=output_attentions,
72
+ output_hidden_states=output_hidden_states,
73
+ return_dict=return_dict,
74
+ )
75
+
76
+ sequence_output = outputs[0]
77
+ logits = self.output(sequence_output)
78
+
79
+ loss = None
80
+ if labels is not None:
81
+ # we are doing next-token prediction; shift prediction scores and input ids by one
82
+ if pixel_values is not None:
83
+ num_image_tokens = self.git.encoder.layer[0].attention.self.image_patch_tokens
84
+ else:
85
+ num_image_tokens = 0
86
+ shifted_logits = logits[:, num_image_tokens:-1, :].contiguous()
87
+ labels = labels[:, 1:].contiguous()
88
+ loss_fct = CrossEntropyLoss()
89
+ loss = loss_fct(shifted_logits.view(-1, self.config.vocab_size), labels.view(-1))
90
+
91
+ if not return_dict:
92
+ output = (logits,) + outputs[1:]
93
+ return ((loss,) + output) if loss is not None else output
94
+
95
+ return modeling_git.CausalLMOutputWithPast(
96
+ loss=loss,
97
+ logits=logits,
98
+ past_key_values=outputs.past_key_values,
99
+ hidden_states=outputs.hidden_states,
100
+ attentions=outputs.attentions,
101
+ )
102
+
103
+ class GitModel(modeling_git.GitForCausalLM):
104
+ def __init__(self, *args, **kwargs):
105
+ super().__init__(*args, **kwargs)
106
+
107
+ del self.output
108
+ self.post_init()
109
+
110
+ del self.git.image_encoder
111
+ self.git.image_encoder = ViTModel.from_pretrained('facebook/dino-vitb16')
112
+ dino_cfg = self.git.image_encoder.config
113
+ config = self.git.config
114
+ config.vision_config.hidden_size = dino_cfg.hidden_size
115
+
116
+ del self.git.visual_projection
117
+ self.git.visual_projection = modeling_git.GitProjection(config)
118
+ num_tks = (dino_cfg.image_size // dino_cfg.patch_size) ** 2 + 1
119
+ self.git.encoder.layer[0].attention.self.image_patch_tokens = num_tks
120
+
121
+ def forward(
122
+ self,
123
+ input_ids: Optional[torch.Tensor] = None,
124
+ attention_mask: Optional[torch.Tensor] = None,
125
+ position_ids: Optional[torch.Tensor] = None,
126
+ pixel_values: Optional[torch.Tensor] = None,
127
+ head_mask: Optional[torch.Tensor] = None,
128
+ inputs_embeds: Optional[torch.Tensor] = None,
129
+ labels: Optional[torch.Tensor] = None,
130
+ past_key_values: Optional[List[torch.Tensor]] = None,
131
+ use_cache: Optional[bool] = None,
132
+ output_attentions: Optional[bool] = None,
133
+ output_hidden_states: Optional[bool] = None,
134
+ return_dict: Optional[bool] = None,
135
+ **kwargs,
136
+ ) -> Union[Tuple[torch.Tensor], modeling_git.CausalLMOutputWithPast]:
137
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
138
+ if labels is not None:
139
+ use_cache = False
140
+
141
+ outputs = self.git(
142
+ input_ids,
143
+ attention_mask=attention_mask,
144
+ position_ids=position_ids,
145
+ pixel_values=pixel_values,
146
+ head_mask=head_mask,
147
+ inputs_embeds=inputs_embeds,
148
+ past_key_values=past_key_values,
149
+ use_cache=use_cache,
150
+ output_attentions=output_attentions,
151
+ output_hidden_states=output_hidden_states,
152
+ return_dict=return_dict,
153
+ )
154
+
155
+ return outputs
156
+