shivavardhineedi committed on
Commit c96f668 · verified · 1 Parent(s): 1a81892

Delete modeling_internvl_chat.py

Files changed (1)
  1. modeling_internvl_chat.py +0 -346
modeling_internvl_chat.py DELETED
@@ -1,346 +0,0 @@
- # --------------------------------------------------------
- # InternVL
- # Copyright (c) 2024 OpenGVLab
- # Licensed under The MIT License [see LICENSE for details]
- # --------------------------------------------------------
- import warnings
- from typing import Any, List, Optional, Tuple, Union
-
- import torch
- import torch.utils.checkpoint
- import transformers
- from torch import nn
- from torch.nn import CrossEntropyLoss
- from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
-                           LlamaTokenizer)
- from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.modeling_utils import PreTrainedModel
- from transformers.utils import ModelOutput, logging
-
- from .configuration_internvl_chat import InternVLChatConfig
- from .conversation import get_conv_template
- from .modeling_intern_vit import InternVisionModel
- from .modeling_internlm2 import InternLM2ForCausalLM
-
- logger = logging.get_logger(__name__)
-
-
- def version_cmp(v1, v2, op='eq'):
-     import operator
-
-     from packaging import version
-     op_func = getattr(operator, op)
-     return op_func(version.parse(v1), version.parse(v2))
-
-
- class InternVLChatModel(PreTrainedModel):
-     config_class = InternVLChatConfig
-     main_input_name = 'pixel_values'
-     _supports_flash_attn_2 = True
-     _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'InternLM2DecoderLayer']
-
-     def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None):
-         super().__init__(config)
-
-         assert version_cmp(transformers.__version__, '4.36.2', 'ge')
-         image_size = config.force_image_size or config.vision_config.image_size
-         patch_size = config.vision_config.patch_size
-         self.patch_size = patch_size
-         self.select_layer = config.select_layer
-         self.template = config.template
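-         # Tokens the LLM receives per image tile after pixel shuffle; e.g. with
-         # image_size=448, patch_size=14, downsample_ratio=0.5 (typical InternVL
-         # settings): (448 // 14) ** 2 * 0.5 ** 2 = 1024 * 0.25 = 256 tokens.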
-         self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
-         self.downsample_ratio = config.downsample_ratio
-         self.ps_version = config.ps_version
-
-         logger.info(f'num_image_token: {self.num_image_token}')
-         logger.info(f'ps_version: {self.ps_version}')
-         if vision_model is not None:
-             self.vision_model = vision_model
-         else:
-             self.vision_model = InternVisionModel(config.vision_config)
-         if language_model is not None:
-             self.language_model = language_model
-         else:
-             if config.llm_config.architectures[0] == 'LlamaForCausalLM':
-                 self.language_model = LlamaForCausalLM(config.llm_config)
-             elif config.llm_config.architectures[0] == 'InternLM2ForCausalLM':
-                 self.language_model = InternLM2ForCausalLM(config.llm_config)
-             else:
-                 raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')
-
-         vit_hidden_size = config.vision_config.hidden_size
-         llm_hidden_size = config.llm_config.hidden_size
-
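-         # Vision-to-language projector. Pixel shuffle folds an int(1/r)**2 block
-         # of spatial tokens into the channel dim (4x for r=0.5), hence the wider
-         # LayerNorm/Linear input below.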
-         self.mlp1 = nn.Sequential(
-             nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
-             nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
-             nn.GELU(),
-             nn.Linear(llm_hidden_size, llm_hidden_size)
-         )
-
-         self.img_context_token_id = None
-         self.conv_template = get_conv_template(self.template)
-         self.system_message = self.conv_template.system_message
-
-     def forward(
-             self,
-             pixel_values: torch.FloatTensor,
-             input_ids: torch.LongTensor = None,
-             attention_mask: Optional[torch.Tensor] = None,
-             position_ids: Optional[torch.LongTensor] = None,
-             image_flags: Optional[torch.LongTensor] = None,
-             past_key_values: Optional[List[torch.FloatTensor]] = None,
-             labels: Optional[torch.LongTensor] = None,
-             use_cache: Optional[bool] = None,
-             output_attentions: Optional[bool] = None,
-             output_hidden_states: Optional[bool] = None,
-             return_dict: Optional[bool] = None,
-     ) -> Union[Tuple, CausalLMOutputWithPast]:
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         image_flags = image_flags.squeeze(-1)
-         input_embeds = self.language_model.get_input_embeddings()(input_ids)
-
-         vit_embeds = self.extract_feature(pixel_values)
-         vit_embeds = vit_embeds[image_flags == 1]
-         vit_batch_size = pixel_values.shape[0]
-
-         B, N, C = input_embeds.shape
-         input_embeds = input_embeds.reshape(B * N, C)
-
-         if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
-             print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')
-
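-         # Scatter the projected vision embeddings into the flattened text
-         # embeddings at every position holding the IMG_CONTEXT placeholder token.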
-         input_ids = input_ids.reshape(B * N)
-         selected = (input_ids == self.img_context_token_id)
-         try:
-             input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
-         except Exception as e:
-             vit_embeds = vit_embeds.reshape(-1, C)
-             print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
-                   f'vit_embeds.shape={vit_embeds.shape}')
-             n_token = selected.sum()
-             input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
-
-         input_embeds = input_embeds.reshape(B, N, C)
-
-         outputs = self.language_model(
-             inputs_embeds=input_embeds,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_values=past_key_values,
-             use_cache=use_cache,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-         )
-         logits = outputs.logits
-
-         loss = None
-         if labels is not None:
-             # Shift so that tokens < n predict n
-             shift_logits = logits[..., :-1, :].contiguous()
-             shift_labels = labels[..., 1:].contiguous()
-             # Flatten the tokens
-             loss_fct = CrossEntropyLoss()
-             shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
-             shift_labels = shift_labels.view(-1)
-             # Enable model parallelism
-             shift_labels = shift_labels.to(shift_logits.device)
-             loss = loss_fct(shift_logits, shift_labels)
-
-         if not return_dict:
-             output = (logits,) + outputs[1:]
-             return (loss,) + output if loss is not None else output
-
-         return CausalLMOutputWithPast(
-             loss=loss,
-             logits=logits,
-             past_key_values=outputs.past_key_values,
-             hidden_states=outputs.hidden_states,
-             attentions=outputs.attentions,
-         )
-
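-     # Space-to-depth downsampling: with scale_factor=0.5, each 2x2 group of
-     # tokens is folded into the channel dim (H and W halve, C quadruples),
-     # so the visual token count drops 4x.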
-     def pixel_shuffle(self, x, scale_factor=0.5):
-         n, w, h, c = x.size()
-         # N, W, H, C --> N, W, H * scale, C // scale
-         x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
-         # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
-         x = x.permute(0, 2, 1, 3).contiguous()
-         # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
-         x = x.view(n, int(h * scale_factor), int(w * scale_factor),
-                    int(c / (scale_factor * scale_factor)))
-         if self.ps_version == 'v1':
-             warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
-                           'which results in a transposed image.')
-         else:
-             x = x.permute(0, 2, 1, 3).contiguous()
-         return x
-
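-     # ViT features -> LLM embedding space: pick the selected encoder layer,
-     # drop the CLS token, pixel-shuffle the (h, w) token grid down by
-     # downsample_ratio, then project with mlp1.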
-     def extract_feature(self, pixel_values):
-         if self.select_layer == -1:
-             vit_embeds = self.vision_model(
-                 pixel_values=pixel_values,
-                 output_hidden_states=False,
-                 return_dict=True).last_hidden_state
-         else:
-             vit_embeds = self.vision_model(
-                 pixel_values=pixel_values,
-                 output_hidden_states=True,
-                 return_dict=True).hidden_states[self.select_layer]
-         vit_embeds = vit_embeds[:, 1:, :]
-
-         h = w = int(vit_embeds.shape[1] ** 0.5)
-         vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
-         vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
-         vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
-         vit_embeds = self.mlp1(vit_embeds)
-         return vit_embeds
-
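-     # Single-turn batched inference: one question per sample. Each '<image>'
-     # placeholder is expanded to IMG_START + num_image_token * num_patches
-     # IMG_CONTEXT tokens + IMG_END before tokenization.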
-     def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
-                    history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
-                    IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
-         if history is not None or return_history:
-             print('Now multi-turn chat is not supported in batch_chat.')
-             raise NotImplementedError
-
-         if image_counts is not None:
-             num_patches_list = image_counts
-             print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')
-
-         img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
-         self.img_context_token_id = img_context_token_id
-
-         if verbose and pixel_values is not None:
-             image_bs = pixel_values.shape[0]
-             print(f'dynamic ViT batch size: {image_bs}')
-
-         queries = []
-         for idx, num_patches in enumerate(num_patches_list):
-             question = questions[idx]
-             if pixel_values is not None and '<image>' not in question:
-                 question = '<image>\n' + question
-             template = get_conv_template(self.template)
-             template.system_message = self.system_message
-             template.append_message(template.roles[0], question)
-             template.append_message(template.roles[1], None)
-             query = template.get_prompt()
-
-             image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
-             query = query.replace('<image>', image_tokens, 1)
-             queries.append(query)
-
-         tokenizer.padding_side = 'left'
-         model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
-         input_ids = model_inputs['input_ids'].cuda()
-         attention_mask = model_inputs['attention_mask'].cuda()
-         eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
-         generation_config['eos_token_id'] = eos_token_id
-         generation_output = self.generate(
-             pixel_values=pixel_values,
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             **generation_config
-         )
-         responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
-         responses = [response.split(template.sep)[0].strip() for response in responses]
-         return responses
-
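-     # Multi-turn inference for a single sample; `history` is a list of
-     # (question, answer) pairs replayed through the conversation template.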
-     def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
-              num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
-              verbose=False):
-
-         if history is None and pixel_values is not None and '<image>' not in question:
-             question = '<image>\n' + question
-
-         if num_patches_list is None:
-             num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
-         assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
-
-         img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
-         self.img_context_token_id = img_context_token_id
-
-         template = get_conv_template(self.template)
-         template.system_message = self.system_message
-         eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
-
-         history = [] if history is None else history
-         for (old_question, old_answer) in history:
-             template.append_message(template.roles[0], old_question)
-             template.append_message(template.roles[1], old_answer)
-         template.append_message(template.roles[0], question)
-         template.append_message(template.roles[1], None)
-         query = template.get_prompt()
-
-         if verbose and pixel_values is not None:
-             image_bs = pixel_values.shape[0]
-             print(f'dynamic ViT batch size: {image_bs}')
-
-         for num_patches in num_patches_list:
-             image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
-             query = query.replace('<image>', image_tokens, 1)
-
-         model_inputs = tokenizer(query, return_tensors='pt')
-         input_ids = model_inputs['input_ids'].cuda()
-         attention_mask = model_inputs['attention_mask'].cuda()
-         generation_config['eos_token_id'] = eos_token_id
-         generation_output = self.generate(
-             pixel_values=pixel_values,
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             **generation_config
-         )
-         response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
-         response = response.split(template.sep)[0].strip()
-         history.append((question, response))
-         if return_history:
-             return response, history
-         else:
-             query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
-             query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
-             if verbose:
-                 print(query_to_print, response)
-             return response
-
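-     # Generation entry point: substitutes vision embeddings for the IMG_CONTEXT
-     # placeholders in the input embeddings, then delegates to the LLM's generate().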
-     @torch.no_grad()
-     def generate(
-             self,
-             pixel_values: Optional[torch.FloatTensor] = None,
-             input_ids: Optional[torch.LongTensor] = None,
-             attention_mask: Optional[torch.LongTensor] = None,
-             visual_features: Optional[torch.FloatTensor] = None,
-             generation_config: Optional[GenerationConfig] = None,
-             output_hidden_states: Optional[bool] = None,
-             return_dict: Optional[bool] = None,
-             **generate_kwargs,
-     ) -> torch.LongTensor:
-
-         assert self.img_context_token_id is not None
-         if pixel_values is not None:
-             if visual_features is not None:
-                 vit_embeds = visual_features
-             else:
-                 vit_embeds = self.extract_feature(pixel_values)
-             input_embeds = self.language_model.get_input_embeddings()(input_ids)
-             B, N, C = input_embeds.shape
-             input_embeds = input_embeds.reshape(B * N, C)
-
-             input_ids = input_ids.reshape(B * N)
-             selected = (input_ids == self.img_context_token_id)
-             assert selected.sum() != 0
-             input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
-
-             input_embeds = input_embeds.reshape(B, N, C)
-         else:
-             input_embeds = self.language_model.get_input_embeddings()(input_ids)
-
-         outputs = self.language_model.generate(
-             inputs_embeds=input_embeds,
-             attention_mask=attention_mask,
-             generation_config=generation_config,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-             use_cache=True,
-             **generate_kwargs,
-         )
-
-         return outputs
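
For reference, the deleted module backed this checkpoint's `chat()` API. A minimal sketch of how it was typically invoked, assuming a hypothetical repo id and a dummy pixel tensor in place of the repo's real image preprocessing:

import torch
from transformers import AutoModel, AutoTokenizer

path = 'OpenGVLab/InternVL-Chat-V1-5'  # hypothetical repo id, not part of this commit
model = AutoModel.from_pretrained(path, torch_dtype=torch.bfloat16,
                                  trust_remote_code=True).cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

# Placeholder input; real use passes (num_patches, 3, 448, 448) from the repo's transform.
pixel_values = torch.randn(1, 3, 448, 448, dtype=torch.bfloat16).cuda()
generation_config = dict(max_new_tokens=256, do_sample=False)
response = model.chat(tokenizer, pixel_values, '<image>\nDescribe this image.',
                      generation_config)
print(response)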