skyipeng committed (verified)
Commit dbc9836 · Parent(s): 89dba90

Upload modeling_internvsl_chat.py

Files changed (1):
  modeling_internvsl_chat.py  +454 -0  (ADDED)
@@ -0,0 +1,454 @@
import io
import warnings
from typing import List, Optional, Tuple, Union

import numpy as np
import scipy.io.wavfile  # used by bytes2wav
import torch
import torch.utils.checkpoint
import transformers
from scipy.signal import resample
from torch import nn
from torch.nn import CrossEntropyLoss
from safetensors.torch import load_file

from transformers import (AutoModel, AutoProcessor, GenerationConfig, LlamaForCausalLM,
                          Qwen2ForCausalLM)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, logging

from .configuration_internvl_chat import InternVLChatConfig
from .conversation import get_conv_template
from .modeling_intern_vit import InternVisionModel, has_flash_attn

logger = logging.get_logger(__name__)


def version_cmp(v1, v2, op='eq'):
    import operator

    from packaging import version
    op_func = getattr(operator, op)
    return op_func(version.parse(v1), version.parse(v2))

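# AdapterV2 bridges the Whisper encoder and the language model: two stride-2 Conv1d
# layers shrink the time axis by roughly 4x while expanding channels 1280 -> 2560 -> 5120,
# then Linear -> ReLU -> Linear projects through a 4096-d hidden layer down to the
# LLM hidden size passed in as `output_dim`.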
class AdapterV2(nn.Module):
    def __init__(
        self,
        output_dim: int,
        **kwargs,
    ):
        super().__init__()

        input_dim = 1280
        embed_dim = 4096

        self.conv1 = nn.Conv1d(input_dim, input_dim * 2, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv1d(input_dim * 2, input_dim * 4, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(input_dim * 4, embed_dim)
        self.fc2 = nn.Linear(embed_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: [B, T, D]
        outputs: [B, T//4, output_dim]
        """
        x = x.transpose(1, 2)  # [B, D, T]
        outputs = nn.functional.gelu(self.conv1(x))
        outputs = nn.functional.gelu(self.conv2(outputs))

        outputs = outputs.transpose(1, 2)  # [B, T//4, 4*D]
        outputs = self.fc1(outputs)
        outputs = nn.functional.relu(outputs)
        outputs = self.fc2(outputs)
        return outputs

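# InternVLChatModel combines an InternViT vision encoder, a Whisper-large-v3 speech
# encoder with the AdapterV2 projector above, and a Llama/Qwen2 language model.
# Image features are pixel-shuffled and projected by `mlp1`, then spliced into the
# token embeddings at the <IMG_CONTEXT> placeholder positions.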
class InternVLChatModel(PreTrainedModel):
    config_class = InternVLChatConfig
    main_input_name = 'pixel_values'
    base_model_prefix = 'language_model'
    _supports_flash_attn_2 = True
    _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'Qwen2DecoderLayer']

    def __init__(self, config: InternVLChatConfig, vision_model=None, speech_model=None, language_model=None, use_flash_attn=True):
        super().__init__(config)

        assert version_cmp(transformers.__version__, '4.37.0', 'ge')
        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.select_layer = config.select_layer
        self.template = config.template
        self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version
        use_flash_attn = use_flash_attn if has_flash_attn else False
        config.vision_config.use_flash_attn = True if use_flash_attn else False
        config.llm_config._attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'

        logger.info(f'num_image_token: {self.num_image_token}')
        logger.info(f'ps_version: {self.ps_version}')
        if vision_model is not None:
            self.vision_model = vision_model
        else:
            self.vision_model = InternVisionModel(config.vision_config)
        if speech_model is not None:
            # store under the attribute name the rest of the class uses
            self.speech_encoder = speech_model
        else:  # TODO: read this from config.speech_config instead of hard-coding Whisper
            speech_encoder_config = transformers.WhisperConfig.from_pretrained(
                "openai/whisper-large-v3",
            )
            self.speech_encoder = transformers.models.whisper.modeling_whisper.WhisperEncoder(speech_encoder_config)
            self.speech_encoder.load_state_dict(
                load_file(
                    "/mnt/data/yu.tang/resource/models--openai--whisper-large-v3/encoder.model.safetensors",
                )
            )

        if language_model is not None:
            self.language_model = language_model
        else:
            if config.llm_config.architectures[0] == 'LlamaForCausalLM':
                self.language_model = LlamaForCausalLM(config.llm_config)
            elif config.llm_config.architectures[0] == 'Qwen2ForCausalLM':
                self.language_model = Qwen2ForCausalLM(config.llm_config)
            else:
                raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.llm_config.hidden_size

        # Whisper processor supplies the log-mel feature extraction used in transform_one
        self.speech_feature_extractor = AutoProcessor.from_pretrained("openai/whisper-large-v3")
        self.speech_adapter = AdapterV2(self.language_model.config.hidden_size)

        self.mlp1 = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size)
        )

        self.img_context_token_id = None
        self.conv_template = get_conv_template(self.template)
        self.system_message = self.conv_template.system_message

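    # forward() splices the ViT features into the text embedding sequence: every
    # position whose input id equals `img_context_token_id` has its embedding
    # replaced by one image-patch embedding, then the result is fed to the LLM.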
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        image_flags: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        image_flags = image_flags.squeeze(-1)
        input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()

        vit_embeds = self.extract_pixel_feature(pixel_values)
        vit_embeds = vit_embeds[image_flags == 1]
        vit_batch_size = pixel_values.shape[0]

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)

        if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
            print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')

        input_ids = input_ids.reshape(B * N)
        selected = (input_ids == self.img_context_token_id)
        try:
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
        except Exception as e:
            vit_embeds = vit_embeds.reshape(-1, C)
            print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
                  f'vit_embeds.shape={vit_embeds.shape}')
            n_token = selected.sum()
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]

        input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

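    # pixel_shuffle trades spatial resolution for channels: with scale_factor=0.5 an
    # [N, H, W, C] grid becomes [N, H/2, W/2, 4C], quartering the number of visual
    # tokens. For example (assuming the usual InternVL setup of 448-px tiles with
    # 14-px patches and downsample_ratio=0.5), the 32x32 patch grid becomes
    # 16x16 = 256 image tokens per tile, matching num_image_token above.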
    def pixel_shuffle(self, x, scale_factor=0.5):
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        if self.ps_version == 'v1':
            warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
                          'which results in a transposed image.')
        else:
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_pixel_feature(self, pixel_values):
        if self.select_layer == -1:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=False,
                return_dict=True).last_hidden_state
        else:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=True,
                return_dict=True).hidden_states[self.select_layer]
        vit_embeds = vit_embeds[:, 1:, :]

        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds

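    # Serving-time speech path: raw WAV bytes -> float32 PCM -> 16 kHz resample ->
    # Whisper log-mel features -> Whisper encoder hidden states. Applying
    # `speech_adapter` to those hidden states is still commented out below, so the
    # returned features are not yet projected to the LLM embedding space.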
    @staticmethod
    def bytes2wav(wav_bytes):
        wav_io = io.BytesIO(wav_bytes)
        wav_io.seek(0)
        sample_rate, waveform = scipy.io.wavfile.read(wav_io)
        return sample_rate, waveform

    def transform_one(self, wav_path):
        """
        This is for serving.
        """
        sr, audio = self.bytes2wav(wav_path)
        audio = audio.astype(np.float32, order='C') / 32768.0

        # Resample the 1-D waveform to 16000 Hz (Whisper's expected sampling rate)
        target_sr = 16000
        if sr != target_sr:
            num_samples = round(audio.shape[0] * float(target_sr) / sr)
            audio = resample(audio, num_samples).astype(np.float32)

        # audio -> mel; the feature extractor takes the raw 1-D waveform. The kwargs
        # below are the standard Whisper processor arguments; adjust them if a custom
        # extractor configuration is needed.
        speech_input = self.speech_feature_extractor(
            audio=audio,
            sampling_rate=target_sr,
            return_tensors='pt',
        )
        mel = speech_input.input_features
        # mel_length = speech_input.attention_mask.sum(dim=1)

        speech_encoder_outputs = self.speech_encoder(mel, return_dict=True)
        speech_encoder_hidden_states = speech_encoder_outputs.last_hidden_state
        return speech_encoder_hidden_states
        # # TODO: might need mask later, subsampling
        # speech_embeds = self.adapter(speech_encoder_hidden_states)
        #
        # return speech_embeds

        # # text -> text token
        # text_prefix = "<|im_start|>system\n<|im_end|>\n<|im_start|>user\n"
        #
        # text_prefix_token = tokenizer(text_prefix).input_ids
        # text_suffix = "<|im_end|>\n<|im_start|>user\n"
        # text_suffix_token = tokenizer(text_suffix).input_ids
        #
        # return {
        #     "__key__": item["__key__"],
        #     "mel": mel,
        #     "mel_length": mel_length,
        #     "text_prefix": text_prefix,
        #     "text_prefix_token": text_prefix_token,
        #     "text_suffix": text_suffix,
        #     "text_suffix_token": text_suffix_token,
        #     "task": task,
        # }

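    # batch_chat formats one single-turn prompt per question, expands each '<image>'
    # placeholder into IMG_START + num_image_token * num_patches context tokens +
    # IMG_END, left-pads the batch, and decodes everything in one generate() call.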
    def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
                   history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
                   IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
        if history is not None or return_history:
            print('Now multi-turn chat is not supported in batch_chat.')
            raise NotImplementedError

        if image_counts is not None:
            num_patches_list = image_counts
            print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')

        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_context_token_id = img_context_token_id

        if verbose and pixel_values is not None:
            image_bs = pixel_values.shape[0]
            print(f'dynamic ViT batch size: {image_bs}')

        queries = []
        for idx, num_patches in enumerate(num_patches_list):
            question = questions[idx]
            if pixel_values is not None and '<image>' not in question:
                question = '<image>\n' + question
            template = get_conv_template(self.template)
            template.system_message = self.system_message
            template.append_message(template.roles[0], question)
            template.append_message(template.roles[1], None)
            query = template.get_prompt()

            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
            query = query.replace('<image>', image_tokens, 1)
            queries.append(query)

        tokenizer.padding_side = 'left'
        model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
        input_ids = model_inputs['input_ids'].to(self.device)
        attention_mask = model_inputs['attention_mask'].to(self.device)
        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
        generation_config['eos_token_id'] = eos_token_id
        generation_output = self.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_config
        )
        responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
        responses = [response.split(template.sep.strip())[0].strip() for response in responses]
        return responses

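    # chat is the single-sample variant: it replays `history` through the conversation
    # template, expands image placeholders the same way as batch_chat, and returns the
    # decoded response (optionally together with the updated history).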
    def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
             num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
             verbose=False):

        if history is None and pixel_values is not None and '<image>' not in question:
            question = '<image>\n' + question

        if num_patches_list is None:
            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)

        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_context_token_id = img_context_token_id

        template = get_conv_template(self.template)
        template.system_message = self.system_message
        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())

        history = [] if history is None else history
        for (old_question, old_answer) in history:
            template.append_message(template.roles[0], old_question)
            template.append_message(template.roles[1], old_answer)
        template.append_message(template.roles[0], question)
        template.append_message(template.roles[1], None)
        query = template.get_prompt()

        if verbose and pixel_values is not None:
            image_bs = pixel_values.shape[0]
            print(f'dynamic ViT batch size: {image_bs}')

        for num_patches in num_patches_list:
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
            query = query.replace('<image>', image_tokens, 1)

        model_inputs = tokenizer(query, return_tensors='pt')
        input_ids = model_inputs['input_ids'].to(self.device)
        attention_mask = model_inputs['attention_mask'].to(self.device)
        generation_config['eos_token_id'] = eos_token_id
        generation_output = self.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_config
        )
        response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
        response = response.split(template.sep.strip())[0].strip()
        history.append((question, response))
        if return_history:
            return response, history
        else:
            query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
            query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
            if verbose:
                print(query_to_print, response)
            return response

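    # generate() accepts a `wav_path` argument, but the speech features are not yet
    # wired into the prompt here; only pixel_values (or precomputed visual_features)
    # are spliced into the input embeddings before delegating to the LLM's generate.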
    @torch.no_grad()
    def generate(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            wav_path: Optional[str] = None,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.LongTensor] = None,
            visual_features: Optional[torch.FloatTensor] = None,
            generation_config: Optional[GenerationConfig] = None,
            output_hidden_states: Optional[bool] = None,
            **generate_kwargs,
    ) -> torch.LongTensor:

        assert self.img_context_token_id is not None
        if pixel_values is not None:
            if visual_features is not None:
                vit_embeds = visual_features
            else:
                vit_embeds = self.extract_pixel_feature(pixel_values)
            input_embeds = self.language_model.get_input_embeddings()(input_ids)
            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)

            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.img_context_token_id)
            assert selected.sum() != 0
            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)

            input_embeds = input_embeds.reshape(B, N, C)
        else:
            input_embeds = self.language_model.get_input_embeddings()(input_ids)

        outputs = self.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            use_cache=True,
            **generate_kwargs,
        )

        return outputs
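
# Example usage (a minimal sketch, not part of the model code: it assumes this file
# ships inside a checkpoint repo loaded with trust_remote_code=True, and the path,
# image preprocessing, and generation settings below are illustrative only):
#
#   from transformers import AutoModel, AutoTokenizer
#   path = 'path/to/this/checkpoint'
#   tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
#   model = AutoModel.from_pretrained(path, trust_remote_code=True).eval()
#   # pixel_values: preprocessed image tensor of shape [num_patches, 3, H, W]
#   response = model.chat(tokenizer, pixel_values, '<image>\nDescribe the image.',
#                         dict(max_new_tokens=256, do_sample=False))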