h0witended committed on
Commit
7699842
·
verified ·
1 Parent(s): 60ca1cc

Delete processing_minicpmo.py

Browse files
Files changed (1) hide show
  1. processing_minicpmo.py +0 -505
processing_minicpmo.py DELETED
@@ -1,505 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2025 The OpenBMB Team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """
16
- Processor class for MiniCPMO.
17
- """
18
-
19
- import math
20
- import re
21
- from typing import List
22
- from typing import Literal
23
- from typing import Optional
24
- from typing import Union
25
-
26
- import numpy as np
27
- import torch
28
- import torchaudio
29
- from transformers.image_utils import ImageInput
30
- from transformers.processing_utils import ProcessorMixin
31
- from transformers.tokenization_utils_base import PreTokenizedInput
32
- from transformers.tokenization_utils_base import TextInput
33
- from transformers.utils import TensorType
34
-
35
- from .image_processing_minicpmv import MiniCPMOBatchFeature
36
-
37
-
38
class MiniCPMOProcessor(ProcessorMixin):
    r"""
    Constructs a MiniCPMO processor which wraps an image processor, an audio feature extractor and a tokenizer into
    a single processor.

    Calling the processor expands every `(<image>./</image>)` and `(<audio>./</audio>)` tag found in the text into
    placeholder tokens, tokenizes the result and returns it together with the image and audio features.

    Args:
        image_processor (*optional*):
            Image processor (loaded through `AutoImageProcessor`). It is called on `images` and must provide
            `get_slice_image_placeholder`.
        feature_extractor (*optional*):
            Audio feature extractor (`WhisperFeatureExtractor`). It is called on the raw waveforms.
        tokenizer (*optional*):
            Tokenizer (loaded through `AutoTokenizer`). It must expose the special-token attributes read by this
            class (`audio_start`, `audio_end`, `im_start_id`, `audio_start_id`, `spk_start_id`, ...).
    """

    attributes = ["image_processor", "feature_extractor", "tokenizer"]
    feature_extractor_class = "WhisperFeatureExtractor"
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor=None, feature_extractor=None, tokenizer=None):
        super().__init__(image_processor, feature_extractor, tokenizer)
        # `image_processor` defaults to None and may not define `version`; fall back to None instead of raising.
        self.version = getattr(image_processor, "version", None)

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
        images: ImageInput = None,
        audios: Union[np.ndarray, List[np.ndarray], List[List[np.ndarray]]] = None,
        audio_parts: Optional[list] = None,
        max_length: Optional[int] = None,
        do_pad: Optional[bool] = True,
        max_slice_nums: int = None,
        use_image_id: bool = True,
        chunk_input: bool = False,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
        sampling_rate: Optional[int] = 16000,
        **kwargs,
    ) -> MiniCPMOBatchFeature:
        """
        Prepare text, images and audio for the model.

        Args:
            text: One prompt or a list of prompts containing `(<image>./</image>)` / `(<audio>./</audio>)` tags.
            images: Images forwarded to the image processor; skipped when `None`.
            audios: One waveform, a list of waveforms (one sample) or a list of lists (one list per sample).
            audio_parts: Optional per-sample part ids, see `audio_feature_extract`.
            max_length: Maximum number of token ids kept per prompt.
            do_pad: Forwarded to the image processor.
            max_slice_nums: Forwarded to the image processor and to `get_slice_image_placeholder`.
            use_image_id: Forwarded to `get_slice_image_placeholder`.
            chunk_input: Whether audio placeholders are emitted chunk by chunk.
            return_tensors: Forwarded to the image processor.
            sampling_rate: Sampling rate of `audios`, forwarded to the feature extractor.
            **kwargs: Forwarded to `_convert_omni_to_inputs` (and from there to `tokenizer.encode`).

        Returns:
            A `MiniCPMOBatchFeature` holding the token, image and audio inputs.
        """
        if images is not None:
            image_inputs = self.image_processor(
                images, do_pad=do_pad, max_slice_nums=max_slice_nums, return_tensors=return_tensors
            )
        else:
            image_inputs = None

        if audios is not None:
            audio_features, audio_feature_lens, audio_phs = self.audio_feature_extract(
                audios, audio_parts, chunk_input, sampling_rate
            )
        else:
            audio_features, audio_feature_lens, audio_phs = [], [], []

        model_inputs = self._convert_omni_to_inputs(
            image_inputs,
            audio_phs,
            text,
            max_slice_nums=max_slice_nums,
            use_image_id=use_image_id,
            max_length=max_length,
            **kwargs,
        )

        model_inputs["audio_features"] = audio_features
        model_inputs["audio_feature_lens"] = audio_feature_lens

        return MiniCPMOBatchFeature(data={**model_inputs})

    def get_audio_placeholder(self, audio_lens, chunk_input, chunk_length):
        """
        Build the placeholder string that stands in for one audio clip in the prompt.

        Args:
            audio_lens: Number of samples in the waveform.
            chunk_input: When true, emit one `audio_start ... audio_end` group per chunk instead of a single group.
            chunk_length: Chunk duration used when `chunk_input` is true (presumably seconds - TODO confirm).

        Returns:
            A string made of `tokenizer.audio_start`, a run of `"<unk>"` and `tokenizer.audio_end`, repeated per
            chunk when `chunk_input` is true.
        """
        pool_step = 2
        feature_lens = math.ceil(audio_lens / self.feature_extractor.hop_length)

        # Halve once, then pool by `pool_step` (presumably mirrors the audio encoder's downsampling - TODO confirm).
        feature_lens = (feature_lens - 1) // 2 + 1
        output_lens = (feature_lens - pool_step) // pool_step + 1

        audio_start = self.tokenizer.audio_start
        audio_end = self.tokenizer.audio_end
        if not chunk_input:
            return audio_start + "<unk>" * output_lens + audio_end

        # NOTE(review): the factor 100 looks like "feature frames per unit of chunk_length" - confirm.
        fbank_feat_in_chunk = int(chunk_length * 100)
        cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1
        audio_embeds_in_chunk = (cnn_feat_in_chunk - pool_step) // pool_step + 1
        num_audio_chunks = (output_lens + audio_embeds_in_chunk - 1) // audio_embeds_in_chunk

        groups = []
        remaining = output_lens
        for _ in range(num_audio_chunks):
            # The last chunk only receives whatever is left.
            unk_len = min(audio_embeds_in_chunk, remaining)
            groups.append(audio_start + "<unk>" * unk_len + audio_end)
            remaining -= unk_len
        return "".join(groups)

    def audio_feature_extract(
        self,
        audios: Union[np.ndarray, List[np.ndarray], List[List[np.ndarray]]],
        audio_parts: Optional[list] = None,
        chunk_input: Optional[bool] = False,
        sampling_rate: Optional[int] = None,
        chunk_length: Optional[int] = 1,
        **kwargs,
    ):
        """
        Turn raw waveforms into padded audio features plus the matching prompt placeholders.

        Args:
            audios: One waveform, a list of waveforms (one sample) or a list of lists (one list per sample).
            audio_parts: Optional list, parallel to the samples, of per-waveform part ids. Consecutive waveforms
                sharing the same id are concatenated before feature extraction.
            chunk_input: Forwarded to `get_audio_placeholder`.
            sampling_rate: Sampling rate of the waveforms. When `None`, the feature extractor's own
                `sampling_rate` is used (16000 if it has none).
            chunk_length: Forwarded to `get_audio_placeholder`.
            **kwargs: Forwarded to the feature extractor.

        Returns:
            A tuple `(audio_features, audio_feature_lens_list, audio_ph_list)`: the zero-padded features of all
            segments (an empty list when there is no audio), the per-sample feature lengths and the per-sample
            placeholder strings.
        """
        if sampling_rate is None:
            # The 30-second split below multiplies by the sampling rate, so it must be a number.
            sampling_rate = getattr(self.feature_extractor, "sampling_rate", 16000)

        # Normalise every accepted input shape to "list of samples, each a list of waveforms".
        if isinstance(audios, np.ndarray):
            audios_list = [[audios]]
        elif len(audios) > 0 and isinstance(audios[0], np.ndarray):
            audios_list = [audios]
        else:
            audios_list = audios

        if audio_parts is not None:
            assert len(audio_parts) == len(audios_list), "audio_parts must have one entry per sample"
            for parts, sample_audios in zip(audio_parts, audios_list):
                assert len(parts) == len(sample_audios), "each audio_parts entry must have one id per waveform"

        audio_feature_lens_list = []
        audio_ph_list = []
        audio_features_all = []

        # Placeholders are computed per original waveform; they do not depend on audio_parts.
        for sample_audios in audios_list:
            if sample_audios:
                audio_ph_list.append(
                    [self.get_audio_placeholder(len(a), chunk_input, chunk_length) for a in sample_audios]
                )
            else:
                audio_ph_list.append([])

        max_audio_inp_len = 30 * sampling_rate
        for idx, sample_audios in enumerate(audios_list):
            if audio_parts is not None:
                # Concatenate runs of consecutive waveforms that carry the same part id.
                part_ids = audio_parts[idx]
                merge_audio = []
                cur_audio = []
                for aid, (_, audio) in enumerate(zip(part_ids, sample_audios)):
                    if aid == 0 or part_ids[aid] == part_ids[aid - 1]:
                        cur_audio.append(audio)
                    else:
                        merge_audio.append(np.hstack(cur_audio))
                        cur_audio = [audio]
                if cur_audio:
                    merge_audio.append(np.hstack(cur_audio))
            else:
                merge_audio = sample_audios

            # If the audio exceeds 30 seconds, split it into chunks every 30 seconds.
            final_merge_audio = []
            for audio in merge_audio:
                if len(audio) <= max_audio_inp_len:
                    final_merge_audio.append(audio)
                else:
                    for start in range(0, len(audio), max_audio_inp_len):
                        final_merge_audio.append(audio[start : start + max_audio_inp_len])

            if sample_audios:
                audio_inputs = self.feature_extractor(
                    final_merge_audio,
                    sampling_rate=sampling_rate,
                    return_attention_mask=True,
                    padding="max_length",
                    return_tensors="pt",
                    **kwargs,
                )
                audio_feature = audio_inputs["input_features"]
                actual_lens = audio_inputs["attention_mask"].sum(dim=1)

                audio_feature_lens = []
                for feat, lens in zip(audio_feature, actual_lens):
                    # Keep only the frames the attention mask marks as real.
                    audio_features_all.append(feat[:, :lens])
                    audio_feature_lens.append(lens)

                audio_feature_lens_list.append(torch.hstack(audio_feature_lens))
            else:
                audio_feature_lens_list.append([])

        if audio_features_all:
            # pad_sequence pads along the first axis, so move the frame axis first and restore it afterwards.
            audio_features = [feat.permute(1, 0) for feat in audio_features_all]
            audio_features = torch.nn.utils.rnn.pad_sequence(
                audio_features, batch_first=True, padding_value=0.0
            ).permute(0, 2, 1)
        else:
            audio_features = []

        return audio_features, audio_feature_lens_list, audio_ph_list

    def _strip_special_ids(self, token_ids):
        """
        Remove ids equal to 0, one leading `bos_id` and one trailing `eos_id` / `eot_id` from `token_ids`.

        `token_ids` must support boolean-mask indexing (e.g. a 1-D tensor). Empty sequences are returned as-is
        instead of raising `IndexError`.
        """
        token_ids = token_ids[token_ids != 0]
        if len(token_ids) > 0 and token_ids[0] == self.tokenizer.bos_id:
            token_ids = token_ids[1:]
        if len(token_ids) > 0:
            last = token_ids[-1]
            # Not every tokenizer defines `eot_id`.
            eot_id = getattr(self.tokenizer, "eot_id", None)
            if last == self.tokenizer.eos_id or (eot_id is not None and last == eot_id):
                token_ids = token_ids[:-1]
        return token_ids

    def batch_decode(self, *args, **kwargs):
        """
        Decode a batch of id sequences into a list of stripped strings.

        The first positional argument is the batch; each sequence is cleaned with `_strip_special_ids` and then
        passed, with the remaining arguments, to the tokenizer's `decode`.
        """
        output_ids = args[0]
        return [
            self.tokenizer.decode(self._strip_special_ids(result), *args[1:], **kwargs).strip()
            for result in output_ids
        ]

    def decode(self, *args, **kwargs):
        """
        Decode one id sequence into a stripped string.

        The first positional argument is the sequence; it is cleaned with `_strip_special_ids` and then passed,
        with the remaining arguments, to the tokenizer's `decode`.
        """
        result = self._strip_special_ids(args[0])
        return self.tokenizer.decode(result, *args[1:], **kwargs).strip()

    def _convert(self, input_str, max_inp_length: Optional[int] = None, **kwargs):
        """
        Tokenize `input_str` and locate the image, audio and speaker spans inside the ids.

        Args:
            input_str: Prompt with all placeholders already expanded.
            max_inp_length: When given, the ids are truncated to this many entries.
            **kwargs: Forwarded to `tokenizer.encode`.

        Returns:
            `(input_ids, image_bounds, audio_bounds, spk_bounds)`. `input_ids` is an int32 tensor; each bounds
            tensor has one `[start, end]` row per span, where `start` is the index right after the start token and
            `end` is the index of the end token.
        """
        input_ids = self.tokenizer.encode(input_str, **kwargs)
        if max_inp_length is not None:
            input_ids = input_ids[:max_inp_length]
        input_ids = torch.tensor(input_ids, dtype=torch.int32)

        ## image bound
        start_cond = (input_ids == self.tokenizer.im_start_id) | (input_ids == self.tokenizer.slice_start_id)
        end_cond = (input_ids == self.tokenizer.im_end_id) | (input_ids == self.tokenizer.slice_end_id)

        image_start_idx = torch.where(start_cond)[0] + 1
        image_end_idx = torch.where(end_cond)[0]

        # Keep only complete start/end pairs: truncation can drop a trailing end token, and hstack needs both
        # columns to have the same number of rows.
        valid_image_nums = min(len(image_start_idx), len(image_end_idx))

        image_bounds = torch.hstack(
            [
                image_start_idx[:valid_image_nums].unsqueeze(-1),
                image_end_idx[:valid_image_nums].unsqueeze(-1),
            ]
        )

        ## audio bound
        audio_start_idx = torch.where(input_ids == self.tokenizer.audio_start_id)[0]
        audio_end_idx = torch.where(input_ids == self.tokenizer.audio_end_id)[0]
        assert len(audio_start_idx) == len(audio_end_idx), "unbalanced audio start/end tokens"
        audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])

        ## speaker bound
        spk_start_idx = torch.where(input_ids == self.tokenizer.spk_start_id)[0]
        spk_end_idx = torch.where(input_ids == self.tokenizer.spk_end_id)[0]
        assert len(spk_start_idx) == len(spk_end_idx), "unbalanced speaker start/end tokens"
        spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])

        return input_ids, image_bounds, audio_bounds, spk_bounds

    def _convert_omni_to_inputs(
        self,
        images,
        audio_phs,
        texts: Union[str, List[str]],
        truncation=None,
        max_length=None,
        max_slice_nums=None,
        use_image_id=None,
        return_tensors=None,
        **kwargs,
    ):
        """
        Expand the image/audio tags of every prompt, tokenize, left-pad and collect the model inputs.

        Args:
            images: Output of the image processor (`pixel_values`, `image_sizes`, `tgt_sizes`) or `None`.
            audio_phs: Per-sample lists of audio placeholder strings, or `None`.
            texts: One prompt or a list of prompts.
            truncation: Only used on the plain-tokenizer path (no images and no audio placeholders).
            max_length: Maximum number of token ids kept per prompt.
            max_slice_nums: Forwarded to `get_slice_image_placeholder`.
            use_image_id: Forwarded to `get_slice_image_placeholder`.
            return_tensors: Only used on the plain-tokenizer path.
            **kwargs: Forwarded to the tokenizer.

        Returns:
            A dict with `input_ids`, `attention_mask`, `pixel_values`, `image_sizes`, `image_bound`, `tgt_sizes`,
            `audio_bounds` and `spk_bounds`; or a `MiniCPMOBatchFeature` wrapping the tokenizer output when both
            `images` and `audio_phs` are `None`.
        """
        if images is None and audio_phs is None:
            model_inputs = self.tokenizer(
                texts, return_tensors=return_tensors, truncation=truncation, max_length=max_length, **kwargs
            )
            return MiniCPMOBatchFeature(data={**model_inputs})

        image_tag = "(<image>./</image>)"
        audio_tag = "(<audio>./</audio>)"
        # Escape the tags so that "(", ")" and "." only match literally; the capturing group makes re.split keep
        # the tags as separate chunks.
        split_regex = re.compile(f"({re.escape(image_tag)}|{re.escape(audio_tag)})")

        if isinstance(texts, str):
            texts = [texts]

        bs = len(texts)
        if images is not None:
            images, image_sizes, tgt_sizes = images["pixel_values"], images["image_sizes"], images["tgt_sizes"]
        else:
            images = [[] for _ in range(bs)]
            image_sizes = [[] for _ in range(bs)]
            tgt_sizes = [[] for _ in range(bs)]

        input_ids_list = []
        image_bounds_list = []
        audio_bounds_list = []
        spk_bounds_list = []

        for index, text in enumerate(texts):
            text_chunks = split_regex.split(text)

            num_image_tags = text_chunks.count(image_tag)
            num_audio_tags = text_chunks.count(audio_tag)

            if num_image_tags:
                assert images is not None
                assert num_image_tags == len(image_sizes[index]), "number of image tags and images differ"
            if num_audio_tags:
                assert audio_phs is not None
                assert num_audio_tags == len(audio_phs[index]), "number of audio tags and audios differ"

            image_id = 0
            audio_id = 0
            for i, chunk in enumerate(text_chunks):
                if chunk == image_tag:
                    text_chunks[i] = self.image_processor.get_slice_image_placeholder(
                        image_sizes[index][image_id], image_id, max_slice_nums, use_image_id
                    )
                    image_id += 1
                elif chunk == audio_tag:
                    text_chunks[i] = audio_phs[index][audio_id]
                    audio_id += 1

            final_text = "".join(text_chunks)
            input_ids, image_bounds, audio_bounds, spk_bounds = self._convert(final_text, max_length, **kwargs)

            input_ids_list.append(input_ids)
            image_bounds_list.append(image_bounds)
            audio_bounds_list.append(audio_bounds)
            spk_bounds_list.append(spk_bounds)

        padded_input_ids, padding_lengths = self.pad(input_ids_list, padding_side="left")
        attention_mask = torch.ones_like(padded_input_ids, dtype=torch.bool)
        for i, length in enumerate(padding_lengths):
            # Left padding shifts every token index by the amount of padding inserted in front.
            image_bounds_list[i] = image_bounds_list[i] + length
            audio_bounds_list[i] = audio_bounds_list[i] + length
            spk_bounds_list[i] = spk_bounds_list[i] + length
            attention_mask[i, :length] = False

        data = {
            "input_ids": padded_input_ids,
            "attention_mask": attention_mask,
            "pixel_values": images,
            "image_sizes": image_sizes,
            "image_bound": image_bounds_list,
            "tgt_sizes": tgt_sizes,
            "audio_bounds": audio_bounds_list,
            "spk_bounds": spk_bounds_list,
        }

        return data

    @property
    def model_input_names(self):
        """De-duplicated, order-preserving union of the input names of the three wrapped components."""
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        feature_extractor_input_names = self.feature_extractor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names + feature_extractor_input_names))

    def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"):
        """
        Stack tensors of different lengths into one batch, padding along their first axis.

        Args:
            inputs: A list of tensors, or a list of lists of tensors (flattened first). Tensors may have at most
                two dimensions.
            max_length: Minimum padded length; the longest item is used when it is larger.
            padding_value: Fill value for the padded positions.
            padding_side: `"left"` pads in front of each item, anything else pads behind it.

        Returns:
            `(tensor, padding_length)` where `tensor` is the stacked batch and `padding_length[i]` is the number
            of padded positions added to item `i`.
        """
        if isinstance(inputs[0], list):
            assert isinstance(inputs[0][0], torch.Tensor)
            items = [tensor for group in inputs for tensor in group]
        else:
            assert isinstance(inputs[0], torch.Tensor)
            items = inputs

        batch_size = len(items)
        shape = items[0].shape
        dim = len(shape)
        assert dim <= 2

        if dim == 0:
            # Scalars have no axis to measure or pad; handle them before touching `shape[0]`.
            return torch.stack(list(items), dim=0), [0] * batch_size

        # Lengths are measured on axis 0, the axis that is actually padded below.
        lengths = [item.shape[0] for item in items]
        max_length = max(max_length or 0, max(lengths))

        if dim == 1 and max_length == min(lengths):
            return torch.stack(list(items), dim=0), [0] * batch_size

        dtype = items[0].dtype
        tensor = torch.full((batch_size, max_length) + tuple(shape[1:]), padding_value, dtype=dtype)

        padding_length = []
        for i, (item, length) in enumerate(zip(items, lengths)):
            pad_amount = max_length - length
            if padding_side == "left":
                # Explicit start index: a negative slice would select the whole row for an empty item.
                tensor[i, pad_amount:] = item
            else:
                tensor[i, :length] = item
            padding_length.append(pad_amount)

        return tensor, padding_length
441
-
442
-
443
class MelSpectrogramFeatures(torch.nn.Module):
    """Module that maps a waveform to the log of its clipped mel spectrogram."""

    def __init__(
        self,
        sample_rate=24000,
        n_fft=1024,
        hop_length=256,
        n_mels=100,
        padding: Literal["center", "same"] = "center",
    ):
        """
        Args:
            sample_rate, n_fft, hop_length, n_mels: Forwarded to `torchaudio.transforms.MelSpectrogram`.
            padding: Either "center" or "same"; the transform is built with `center=True` only for "center".

        Raises:
            ValueError: If `padding` is neither "center" nor "same".
        """
        super().__init__()
        allowed_paddings = ("center", "same")
        if padding not in allowed_paddings:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding

        use_center = padding == "center"
        # power=1 asks the transform for a magnitude (not power) spectrogram.
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            center=use_center,
            power=1,
        )

    def __call__(self, audio: torch.Tensor) -> torch.Tensor:
        """
        audio: Tensor([num_channels, num_samples])
        """
        # Defined only to attach a typed signature; the work happens in `forward`.
        return super().__call__(audio)

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """
        audio: Tensor([num_channels, num_samples])
        """
        mel: torch.Tensor = self.mel_spec(audio)
        # Clip from below so the logarithm never sees zero.
        clipped = torch.clip(mel, min=1e-5)
        return torch.log(clipped)
478
-
479
-
480
class ChatTTSProcessor:
    """Pairs a text tokenizer with a `MelSpectrogramFeatures` module to prepare TTS inputs."""

    def __init__(self, text_tokenizer):
        """
        Args:
            text_tokenizer: Tokenizer whose `encode(text, return_tensors="pt", add_special_tokens=False)` is used
                to turn each text into ids.
        """
        self.audio_processor = MelSpectrogramFeatures()
        self.text_tokenizer = text_tokenizer

    def __call__(self, text_list, audio_list):
        """
        Tokenize every text and compute the audio features of every waveform.

        Args:
            text_list: Texts to tokenize.
            audio_list: 1-D waveform tensors, one per text.

        Returns:
            A dict with `tts_input_ids_varlen` (list of id tensors, one per text, with the leading batch
            dimension squeezed away) and `tts_input_features_varlen` (list of feature tensors, one per waveform).

        Raises:
            AssertionError: If the two lists differ in length or a waveform is not 1-D.
        """
        assert len(text_list) == len(audio_list), "text_list and audio_list must have the same length"

        # encode(..., return_tensors="pt") yields [1, seq_len]; squeeze to [seq_len].
        input_ids_varlen = [
            self.text_tokenizer.encode(text, return_tensors="pt", add_special_tokens=False).squeeze(0)
            for text in text_list
        ]

        audio_features_varlen = []
        for audio in audio_list:
            assert len(audio.shape) == 1, "each audio must be a 1-D tensor"  # [seq_len]
            # Errors from the audio processor propagate unchanged.
            audio_features_varlen.append(self.audio_processor(audio))

        return {
            "tts_input_ids_varlen": input_ids_varlen,  # return List[Tensor]
            "tts_input_features_varlen": audio_features_varlen,  # return List[Tensor]
        }