skyipeng committed
Commit 5481589 · verified · 1 Parent(s): dbc9836

Delete modeling_internvsl_chat.py

Files changed (1)
  1. modeling_internvsl_chat.py +0 -454
modeling_internvsl_chat.py DELETED
@@ -1,454 +0,0 @@
- import io
- import warnings
- from typing import List, Optional, Tuple, Union
-
- import numpy as np
- import scipy.io.wavfile  # used by bytes2wav below
- import torch
- import torch.utils.checkpoint
- import transformers
- from scipy.signal import resample
- from torch import nn
- from torch.nn import CrossEntropyLoss
- from safetensors.torch import load_file
-
- from transformers import (AutoModel, AutoProcessor, GenerationConfig, LlamaForCausalLM,
-                           Qwen2ForCausalLM)
- from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.modeling_utils import PreTrainedModel
- from transformers.utils import ModelOutput, logging
-
- from .configuration_internvl_chat import InternVLChatConfig
- from .conversation import get_conv_template
- from .modeling_intern_vit import InternVisionModel, has_flash_attn
-
- logger = logging.get_logger(__name__)
-
-
- def version_cmp(v1, v2, op='eq'):
-     import operator
-
-     from packaging import version
-     op_func = getattr(operator, op)
-     return op_func(version.parse(v1), version.parse(v2))
-
-
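- # Speech adapter: two strided 1-D convolutions cut the time dimension of the
- # 1280-d Whisper encoder features by 4x, then a two-layer MLP projects them to
- # the language model's hidden size (output_dim).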
- class AdapterV2(nn.Module):
-     def __init__(
-         self,
-         output_dim: int,
-         **kwargs,
-     ):
-         super().__init__()
-
-         input_dim = 1280
-         embed_dim = 4096
-
-         self.conv1 = nn.Conv1d(input_dim, input_dim * 2, kernel_size=3, stride=2, padding=1)
-         self.conv2 = nn.Conv1d(input_dim * 2, input_dim * 4, kernel_size=3, stride=2, padding=1)
-         self.fc1 = nn.Linear(input_dim * 4, embed_dim)
-         self.fc2 = nn.Linear(embed_dim, output_dim)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """
-         x: [B, T, D]
-         outputs: [B, T//4, output_dim]
-         """
-         x = x.transpose(1, 2)  # [B, D, T]
-         outputs = nn.functional.gelu(self.conv1(x))
-         outputs = nn.functional.gelu(self.conv2(outputs))
-
-         outputs = outputs.transpose(1, 2)  # [B, T//4, 4 * D]
-         outputs = self.fc1(outputs)
-         outputs = nn.functional.relu(outputs)
-         outputs = self.fc2(outputs)
-         return outputs
-
-
- class InternVLChatModel(PreTrainedModel):
-     config_class = InternVLChatConfig
-     main_input_name = 'pixel_values'
-     base_model_prefix = 'language_model'
-     _supports_flash_attn_2 = True
-     _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'Qwen2DecoderLayer']
-
-     def __init__(self, config: InternVLChatConfig, vision_model=None, speech_model=None, language_model=None, use_flash_attn=True):
-         super().__init__(config)
-
-         assert version_cmp(transformers.__version__, '4.37.0', 'ge')
-         image_size = config.force_image_size or config.vision_config.image_size
-         patch_size = config.vision_config.patch_size
-         self.patch_size = patch_size
-         self.select_layer = config.select_layer
-         self.template = config.template
-         self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
-         self.downsample_ratio = config.downsample_ratio
-         self.ps_version = config.ps_version
-         use_flash_attn = use_flash_attn if has_flash_attn else False
-         config.vision_config.use_flash_attn = True if use_flash_attn else False
-         config.llm_config._attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'
-
-         logger.info(f'num_image_token: {self.num_image_token}')
-         logger.info(f'ps_version: {self.ps_version}')
-         if vision_model is not None:
-             self.vision_model = vision_model
-         else:
-             self.vision_model = InternVisionModel(config.vision_config)
-         if speech_model is not None:
-             # Keep one attribute name: the rest of the class only uses self.speech_encoder.
-             self.speech_encoder = speech_model
-         else:  # TODO: read this from config.speech_config instead of hard-coding Whisper
-             speech_encoder_config = transformers.WhisperConfig.from_pretrained(
-                 "openai/whisper-large-v3",
-             )
-             self.speech_encoder = transformers.models.whisper.modeling_whisper.WhisperEncoder(speech_encoder_config)
-             self.speech_encoder.load_state_dict(
-                 load_file(
-                     "/mnt/data/yu.tang/resource/models--openai--whisper-large-v3/encoder.model.safetensors",
-                 )
-             )
-
-         if language_model is not None:
-             self.language_model = language_model
-         else:
-             if config.llm_config.architectures[0] == 'LlamaForCausalLM':
-                 self.language_model = LlamaForCausalLM(config.llm_config)
-             elif config.llm_config.architectures[0] == 'Qwen2ForCausalLM':
-                 self.language_model = Qwen2ForCausalLM(config.llm_config)
-             else:
-                 raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')
-
-         vit_hidden_size = config.vision_config.hidden_size
-         llm_hidden_size = config.llm_config.hidden_size
-
-         # NOTE: self.speech_feature_extractor_config is never defined in this file; it is
-         # assumed to be attached elsewhere before __init__ runs.
-         self.speech_feature_extractor = AutoProcessor.from_pretrained(
-             "openai/whisper-large-v3",
-             cache_dir=self.speech_feature_extractor_config.cache_dir)
-         self.speech_adapter = AdapterV2(self.language_model.config.hidden_size)
-
-         self.mlp1 = nn.Sequential(
-             nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
-             nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
-             nn.GELU(),
-             nn.Linear(llm_hidden_size, llm_hidden_size)
-         )
-
-         self.img_context_token_id = None
-         self.conv_template = get_conv_template(self.template)
-         self.system_message = self.conv_template.system_message
-
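-     # forward() overwrites the input embeddings at <IMG_CONTEXT> token positions with
-     # the projected ViT features (see extract_pixel_feature) before calling the LLM.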
-     def forward(
-         self,
-         pixel_values: torch.FloatTensor,
-         input_ids: torch.LongTensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         image_flags: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[List[torch.FloatTensor]] = None,
-         labels: Optional[torch.LongTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ) -> Union[Tuple, CausalLMOutputWithPast]:
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         image_flags = image_flags.squeeze(-1)
-         input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()
-
-         vit_embeds = self.extract_pixel_feature(pixel_values)
-         vit_embeds = vit_embeds[image_flags == 1]
-         vit_batch_size = pixel_values.shape[0]
-
-         B, N, C = input_embeds.shape
-         input_embeds = input_embeds.reshape(B * N, C)
-
-         if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
-             print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')
-
-         input_ids = input_ids.reshape(B * N)
-         selected = (input_ids == self.img_context_token_id)
-         try:
-             input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
-         except Exception as e:
-             vit_embeds = vit_embeds.reshape(-1, C)
-             print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
-                   f'vit_embeds.shape={vit_embeds.shape}')
-             n_token = selected.sum()
-             input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
-
-         input_embeds = input_embeds.reshape(B, N, C)
-
-         outputs = self.language_model(
-             inputs_embeds=input_embeds,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_values=past_key_values,
-             use_cache=use_cache,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-         )
-         logits = outputs.logits
-
-         loss = None
-         if labels is not None:
-             # Shift so that tokens < n predict n
-             shift_logits = logits[..., :-1, :].contiguous()
-             shift_labels = labels[..., 1:].contiguous()
-             # Flatten the tokens
-             loss_fct = CrossEntropyLoss()
-             shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
-             shift_labels = shift_labels.view(-1)
-             # Enable model parallelism
-             shift_labels = shift_labels.to(shift_logits.device)
-             loss = loss_fct(shift_logits, shift_labels)
-
-         if not return_dict:
-             output = (logits,) + outputs[1:]
-             return (loss,) + output if loss is not None else output
-
-         return CausalLMOutputWithPast(
-             loss=loss,
-             logits=logits,
-             past_key_values=outputs.past_key_values,
-             hidden_states=outputs.hidden_states,
-             attentions=outputs.attentions,
-         )
-
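-     # pixel_shuffle with scale_factor=0.5 folds each 2x2 block of spatial tokens into
-     # one token with 4x the channels, reducing the number of visual tokens by 4x.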
-     def pixel_shuffle(self, x, scale_factor=0.5):
-         n, w, h, c = x.size()
-         # N, W, H, C --> N, W, H * scale, C // scale
-         x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
-         # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
-         x = x.permute(0, 2, 1, 3).contiguous()
-         # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
-         x = x.view(n, int(h * scale_factor), int(w * scale_factor),
-                    int(c / (scale_factor * scale_factor)))
-         if self.ps_version == 'v1':
-             warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
-                           'which results in a transposed image.')
-         else:
-             x = x.permute(0, 2, 1, 3).contiguous()
-         return x
-
-     def extract_pixel_feature(self, pixel_values):
-         if self.select_layer == -1:
-             vit_embeds = self.vision_model(
-                 pixel_values=pixel_values,
-                 output_hidden_states=False,
-                 return_dict=True).last_hidden_state
-         else:
-             vit_embeds = self.vision_model(
-                 pixel_values=pixel_values,
-                 output_hidden_states=True,
-                 return_dict=True).hidden_states[self.select_layer]
-         vit_embeds = vit_embeds[:, 1:, :]
-
-         h = w = int(vit_embeds.shape[1] ** 0.5)
-         vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
-         vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
-         vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
-         vit_embeds = self.mlp1(vit_embeds)
-         return vit_embeds
-
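-     # Speech serving path: raw WAV bytes -> float32 waveform -> resample to 16 kHz ->
-     # Whisper log-mel features -> Whisper encoder hidden states.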
-     @staticmethod
-     def bytes2wav(wav_bytes):
-         wav_io = io.BytesIO(wav_bytes)
-         wav_io.seek(0)
-         sample_rate, waveform = scipy.io.wavfile.read(wav_io)
-         return sample_rate, waveform
-
-     def transform_one(self, wav_path):
-         """
-         Serving path. Note: despite the name, `wav_path` receives raw WAV bytes (see bytes2wav).
-         """
-         sr, audio = self.bytes2wav(wav_path)
-         audio = (audio.astype(np.float32, order='C') / 32768.0)
-         audio = torch.from_numpy(audio)[None, :]
-
-         # Resample to 16000 Hz
-         target_sr = 16000
-         if sr != target_sr:
-             # Resample along the time axis (audio is [1, T], so len(audio) would give 1 here).
-             num_samples = round(audio.shape[-1] * float(target_sr) / sr)
-             audio = resample(audio, num_samples, axis=-1)
-
-         # audio -> mel
-         # NOTE: self.speech_feature_extractor_config is assumed to be set elsewhere; it is not defined in this file.
-         speech_input = self.speech_feature_extractor(
-             audio=audio,
-             **self.speech_feature_extractor_config.call_kwargs
-         )
-         mel = speech_input.input_features
-         # mel_length = speech_input.attention_mask.sum(dim=1)
-
-         speech_encoder_outputs = self.speech_encoder(mel, return_dict=True)
-         speech_encoder_hidden_states = speech_encoder_outputs.last_hidden_state
-         return speech_encoder_hidden_states
-         # # TODO: might need mask later, subsampling
-         # speech_embeds = self.adapter(speech_encoder_hidden_states)
-         #
-         # return speech_embeds
-
-         # # text -> text token
-         # text_prefix = "<|im_start|>system\n<|im_end|>\n<|im_start|>user\n"
-         #
-         # text_prefix_token = tokenizer(text_prefix).input_ids
-         # text_suffix = "<|im_end|>\n<|im_start|>user\n"
-         # text_suffix_token = tokenizer(text_suffix).input_ids
-         #
-         # return {
-         #     "__key__": item["__key__"],
-         #     "mel": mel,
-         #     "mel_length": mel_length,
-         #     "text_prefix": text_prefix,
-         #     "text_prefix_token": text_prefix_token,
-         #     "text_suffix": text_suffix,
-         #     "text_suffix_token": text_suffix_token,
-         #     "task": task,
-         # }
-
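-     # batch_chat formats one prompt per question, expands each '<image>' placeholder into
-     # num_image_token * num_patches <IMG_CONTEXT> tokens, and decodes with left padding.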
-     def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
-                    history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
-                    IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
-         if history is not None or return_history:
-             print('Now multi-turn chat is not supported in batch_chat.')
-             raise NotImplementedError
-
-         if image_counts is not None:
-             num_patches_list = image_counts
-             print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')
-
-         img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
-         self.img_context_token_id = img_context_token_id
-
-         if verbose and pixel_values is not None:
-             image_bs = pixel_values.shape[0]
-             print(f'dynamic ViT batch size: {image_bs}')
-
-         queries = []
-         for idx, num_patches in enumerate(num_patches_list):
-             question = questions[idx]
-             if pixel_values is not None and '<image>' not in question:
-                 question = '<image>\n' + question
-             template = get_conv_template(self.template)
-             template.system_message = self.system_message
-             template.append_message(template.roles[0], question)
-             template.append_message(template.roles[1], None)
-             query = template.get_prompt()
-
-             image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
-             query = query.replace('<image>', image_tokens, 1)
-             queries.append(query)
-
-         tokenizer.padding_side = 'left'
-         model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
-         input_ids = model_inputs['input_ids'].to(self.device)
-         attention_mask = model_inputs['attention_mask'].to(self.device)
-         eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
-         generation_config['eos_token_id'] = eos_token_id
-         generation_output = self.generate(
-             pixel_values=pixel_values,
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             **generation_config
-         )
-         responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
-         responses = [response.split(template.sep.strip())[0].strip() for response in responses]
-         return responses
-
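-     # chat is the single-sample variant: it replays the (question, answer) history through
-     # the conversation template before appending the new question.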
-     def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
-              num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
-              verbose=False):
-
-         if history is None and pixel_values is not None and '<image>' not in question:
-             question = '<image>\n' + question
-
-         if num_patches_list is None:
-             num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
-         assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
-
-         img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
-         self.img_context_token_id = img_context_token_id
-
-         template = get_conv_template(self.template)
-         template.system_message = self.system_message
-         eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
-
-         history = [] if history is None else history
-         for (old_question, old_answer) in history:
-             template.append_message(template.roles[0], old_question)
-             template.append_message(template.roles[1], old_answer)
-         template.append_message(template.roles[0], question)
-         template.append_message(template.roles[1], None)
-         query = template.get_prompt()
-
-         if verbose and pixel_values is not None:
-             image_bs = pixel_values.shape[0]
-             print(f'dynamic ViT batch size: {image_bs}')
-
-         for num_patches in num_patches_list:
-             image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
-             query = query.replace('<image>', image_tokens, 1)
-
-         model_inputs = tokenizer(query, return_tensors='pt')
-         input_ids = model_inputs['input_ids'].to(self.device)
-         attention_mask = model_inputs['attention_mask'].to(self.device)
-         generation_config['eos_token_id'] = eos_token_id
-         generation_output = self.generate(
-             pixel_values=pixel_values,
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             **generation_config
-         )
-         response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
-         response = response.split(template.sep.strip())[0].strip()
-         history.append((question, response))
-         if return_history:
-             return response, history
-         else:
-             query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
-             query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
-             if verbose:
-                 print(query_to_print, response)
-             return response
-
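-     # generate() splices the visual features into the <IMG_CONTEXT> positions of the prompt
-     # embeddings and delegates to language_model.generate(); the wav_path argument is
-     # accepted here but not used in this method.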
-     @torch.no_grad()
-     def generate(
-             self,
-             pixel_values: Optional[torch.FloatTensor] = None,
-             wav_path: Optional[str] = None,
-             input_ids: Optional[torch.LongTensor] = None,
-             attention_mask: Optional[torch.LongTensor] = None,
-             visual_features: Optional[torch.FloatTensor] = None,
-             generation_config: Optional[GenerationConfig] = None,
-             output_hidden_states: Optional[bool] = None,
-             **generate_kwargs,
-     ) -> torch.LongTensor:
-
-         assert self.img_context_token_id is not None
-         if pixel_values is not None:
-             if visual_features is not None:
-                 vit_embeds = visual_features
-             else:
-                 vit_embeds = self.extract_pixel_feature(pixel_values)
-             input_embeds = self.language_model.get_input_embeddings()(input_ids)
-             B, N, C = input_embeds.shape
-             input_embeds = input_embeds.reshape(B * N, C)
-
-             input_ids = input_ids.reshape(B * N)
-             selected = (input_ids == self.img_context_token_id)
-             assert selected.sum() != 0
-             input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
-
-             input_embeds = input_embeds.reshape(B, N, C)
-         else:
-             input_embeds = self.language_model.get_input_embeddings()(input_ids)
-
-         outputs = self.language_model.generate(
-             inputs_embeds=input_embeds,
-             attention_mask=attention_mask,
-             generation_config=generation_config,
-             output_hidden_states=output_hidden_states,
-             use_cache=True,
-             **generate_kwargs,
-         )
-
-         return outputs