CyberBoyNull committed on
Commit cb65f9f · verified · 1 Parent(s): c02e7ae

Upload folder

__init__.py ADDED
File without changes
__pycache__/__init__.cpython-310.pyc ADDED
Binary file (145 Bytes)

__pycache__/configuration_qualityv.cpython-310.pyc ADDED
Binary file (2.58 kB)

__pycache__/modeling_qualityv.cpython-310.pyc ADDED
Binary file (8.12 kB)

__pycache__/processing_qualityv.cpython-310.pyc ADDED
Binary file (9.89 kB)
configuration_qualityv.py ADDED
@@ -0,0 +1,78 @@
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers import AutoConfig
+ from transformers.activations import ACT2FN
+
+
+ class QualityLinearAdapterConfig(PretrainedConfig):
+     model_type = "QualityvForCausalLM"
+     adapter_type = "linear"
+
+     def __init__(self,
+                  in_hidden_size: int = 1024,
+                  num_layers: int = 2,
+                  intermediate_size: int = 2048,
+                  out_hidden_size: int = 2028,
+                  act_fn: str = "gelu",
+                  **kwargs,
+                  ) -> None:
+         super().__init__(**kwargs)
+
+         self.in_hidden_size = in_hidden_size
+         self.num_layers = num_layers
+         self.intermediate_size = intermediate_size
+         self.out_hidden_size = out_hidden_size
+         self.act_fn = act_fn
+
+
+ class QualityvConfig(PretrainedConfig):
+     model_type = "QualityvForCausalLM"
+     def __init__(self,
+                  vision_model_name: str = None,
+                  audio_model_name: str = None,
+                  llm_model_name: str = None,
+                  image_token_id: int = None,
+                  video_token_id: int = None,
+                  audio_token_id: int = None,
+                  adapter_type: str = "linear",
+                  num_adapter_layers: int = 2,
+                  **kwargs,
+                  ) -> None:
+         super().__init__(**kwargs)
+         self.vision_model_name = vision_model_name
+         self.audio_model_name = audio_model_name
+         self.llm_model_name = llm_model_name
+         self.image_token_id = image_token_id
+         self.video_token_id = video_token_id
+         self.audio_token_id = audio_token_id
+         self.adapter_type = adapter_type
+         self.num_adapter_layers = num_adapter_layers
+         if llm_model_name is not None:
+             self.llm_config = AutoConfig.from_pretrained(llm_model_name)
+             for key, value in self.llm_config.to_dict().items():
+                 setattr(self, key, value)
+         if vision_model_name is not None:
+             self.vision_config = AutoConfig.from_pretrained(vision_model_name)
+             self.vision_adapter_config = QualityLinearAdapterConfig(
+                 in_hidden_size=self.vision_config.hidden_size,
+                 intermediate_size=self.vision_config.hidden_size * 2,
+                 out_hidden_size=self.llm_config.hidden_size,
+                 num_layers=num_adapter_layers,
+             )
+         else:
+             self.vision_config = None
+         if audio_model_name is not None:
+             self.audio_config = AutoConfig.from_pretrained(audio_model_name)
+             self.audio_adapter_config = QualityLinearAdapterConfig(
+                 in_hidden_size=self.audio_config.hidden_size,
+                 intermediate_size=self.audio_config.hidden_size * 2,
+                 out_hidden_size=self.llm_config.hidden_size,
+                 num_layers=num_adapter_layers,
+             )
+         else:
+             self.audio_config = None
+
+     def get_vocab_size(self):
+         return self.llm_config.vocab_size
+
+     def get_text_config(self, **kwargs):
+         return self.llm_config.get_text_config(**kwargs)
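A rough usage sketch follows (the checkpoint names are illustrative placeholders, not necessarily the backbones this repository was built with): the config pulls in the three backbone configs with AutoConfig and sizes each adapter from the encoder hidden size to the LLM hidden size.

# Sketch under assumptions: any AutoConfig-loadable vision/audio/LLM checkpoints should work;
# the names below are placeholders chosen for illustration only.
from configuration_qualityv import QualityvConfig

config = QualityvConfig(
    vision_model_name="google/vit-base-patch16-224",  # vision encoder
    audio_model_name="openai/whisper-small",          # audio model (d_model exposed as hidden_size)
    llm_model_name="Qwen/Qwen2.5-0.5B-Instruct",      # causal LM backbone
    num_adapter_layers=2,
)

# Each adapter is sized encoder hidden_size -> 2 * encoder hidden_size -> LLM hidden_size.
assert config.vision_adapter_config.out_hidden_size == config.llm_config.hidden_size
assert config.audio_adapter_config.in_hidden_size == config.audio_config.hidden_size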
modeling_qualityv.py ADDED
@@ -0,0 +1,241 @@
+ from transformers import AutoModel, AutoModelForCausalLM
+ from transformers.activations import ACT2FN
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.generation.utils import GenerationMixin
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ import torch
+ from torch import nn
+ from torch.nn import RMSNorm
+ from typing import List, Optional
+
+ from .configuration_qualityv import QualityvConfig, QualityLinearAdapterConfig
+
+
+ class QualityLinearAdapter(nn.Module):
+     def __init__(self, config: QualityLinearAdapterConfig):
+         super().__init__()
+         self.config = config
+         self.norm = RMSNorm(config.in_hidden_size)
+         self.act_fn = ACT2FN[config.act_fn]
+         if config.num_layers == 1:
+             self.linears = nn.Linear(config.in_hidden_size, config.out_hidden_size)
+         else:
+             model_list = []
+             for i in range(config.num_layers - 1):
+                 model_list.append(nn.Linear(config.in_hidden_size if i == 0 else config.intermediate_size, config.intermediate_size))
+                 model_list.append(self.act_fn)
+             model_list.append(nn.Linear(config.intermediate_size, config.out_hidden_size))
+             self.linears = nn.Sequential(*model_list)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.linears(self.norm(x))
+         return x
+
+
+ class QualityvForCausalLM(PreTrainedModel, GenerationMixin):
+
+
+     def __init__(self, config: QualityvConfig, *args, **kwargs):
+         super().__init__(config, *args, **kwargs)
+         self.config = config
+         self.llm_model = AutoModelForCausalLM.from_pretrained(config.llm_model_name)
+         if config.vision_config is not None:
+             self.vision_model = AutoModel.from_pretrained(config.vision_model_name)
+             self.vision_adapter = QualityLinearAdapter(config.vision_adapter_config)
+         if config.audio_config is not None:
+             self.audio_model = AutoModel.from_pretrained(config.audio_model_name)
+             self.audio_adapter = QualityLinearAdapter(config.audio_adapter_config)
+             self.decoder_input_ids = torch.tensor([[1, 1]]) * self.audio_model.config.decoder_start_token_id
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.llm_model.get_input_embeddings()
+
+     def set_input_embeddings(self, value):
+         self.llm_model.set_input_embeddings(value)
+
+     def get_output_embeddings(self):
+         return self.llm_model.get_output_embeddings()
+
+     def set_output_embeddings(self, value):
+         self.llm_model.set_output_embeddings(value)
+
+     def set_decoder(self, decoder):
+         self.llm_model.set_decoder(decoder)
+
+     def get_decoder(self):
+         return self.llm_model.get_decoder()
+
+     def get_vision_model(self):
+         return self.vision_model
+
+     def get_audio_model(self):
+         return self.audio_model
+
+     def get_video_features(self, pixel_values_videos: torch.Tensor) -> torch.Tensor:
+         video_embeds = self.vision_model(pixel_values_videos).last_hidden_state
+         video_embeds = self.vision_adapter(video_embeds)
+         return video_embeds
+
+     def get_audio_features(self, audio_values: torch.Tensor) -> torch.Tensor:
+         audio_embeds = self.audio_model.encoder(audio_values).last_hidden_state
+         audio_embeds = self.audio_adapter(audio_embeds)
+         return audio_embeds
+
+     def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
+         image_embeds = self.vision_model(pixel_values).last_hidden_state
+         image_embeds = self.vision_adapter(image_embeds)
+         return image_embeds
+
+     def replace_multi_modal_embeddings(self, multi_modal_embeds: torch.Tensor,
+                                        input_embeds: torch.Tensor,
+                                        input_ids: torch.LongTensor,
+                                        multi_modal_token_id: int,
+                                        note: str = "multi_modal"):
+         # multi_modal_embeds: batch_size * num_frames, hidden_steps, hidden_size
+         # input_embeds: batch_size, seq_length, hidden_size
+         # input_ids: batch_size, seq_length
+         # multi_modal_token_id: int
+         # note: str
+         hidden_size = multi_modal_embeds.shape[-1]
+         multi_modal_embeds = multi_modal_embeds.view(-1, hidden_size)
+         n_modal_tokens = (input_ids == multi_modal_token_id).sum()
+         n_modal_embeds = multi_modal_embeds.shape[0]
+         if n_modal_tokens != n_modal_embeds:
+             raise ValueError(f"The number of {note} tokens ({n_modal_tokens}) does not match the number of {note} embeddings ({n_modal_embeds}).")
+         mask = input_ids == multi_modal_token_id
+         mask_unsqueezed = mask.unsqueeze(-1)
+         mask_expanded = mask_unsqueezed.expand_as(input_embeds)
+         modal_mask = mask_expanded.to(input_embeds.device)
+         multi_modal_embeds = multi_modal_embeds.to(input_embeds.device, dtype=input_embeds.dtype)
+         input_embeds = input_embeds.masked_scatter(modal_mask, multi_modal_embeds)
+         return input_embeds
+
+
+     def forward(self,
+                 input_ids: torch.LongTensor = None,
+                 attention_mask: Optional[torch.Tensor] = None,
+                 position_ids: Optional[torch.LongTensor] = None,
+                 past_key_values: Optional[List[torch.FloatTensor]] = None,
+                 inputs_embeds: Optional[torch.FloatTensor] = None,
+                 labels: Optional[torch.LongTensor] = None,
+                 use_cache: Optional[bool] = None,
+                 output_attentions: Optional[bool] = None,
+                 output_hidden_states: Optional[bool] = None,
+                 return_dict: Optional[bool] = None,
+                 pixel_values: Optional[torch.Tensor] = None,
+                 pixel_values_videos: Optional[torch.FloatTensor] = None,
+                 audio_values: Optional[torch.FloatTensor] = None,
+                 cache_position: Optional[torch.LongTensor] = None,
+                 **kwargs
+                 ):
+         output_attentions = output_attentions if output_attentions is not None else self.config.llm_config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.llm_config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.llm_config.use_return_dict
+
+         if inputs_embeds is None:
+             inputs_embeds = self.get_input_embeddings()(input_ids)
+
+         if pixel_values_videos is not None:
+             video_features = self.get_video_features(pixel_values_videos)
+             inputs_embeds = self.replace_multi_modal_embeddings(video_features, inputs_embeds, input_ids, self.config.video_token_id, note="video")
+
+         if pixel_values is not None:
+             image_features = self.get_image_features(pixel_values)
+             inputs_embeds = self.replace_multi_modal_embeddings(image_features, inputs_embeds, input_ids, self.config.image_token_id, note="image")
+
+         if audio_values is not None:
+             audio_features = self.get_audio_features(audio_values)
+             inputs_embeds = self.replace_multi_modal_embeddings(audio_features, inputs_embeds, input_ids, self.config.audio_token_id, note="audio")
+
+         outputs = self.llm_model(
+             input_ids=None,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             labels=labels,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             cache_position=cache_position,
+             **kwargs
+         )
+
+         return outputs
+
+
+
+     def prepare_inputs_for_generation(self,
+                                       input_ids,
+                                       past_key_values=None,
+                                       attention_mask=None,
+                                       use_cache=None,
+                                       pixel_values=None,
+                                       pixel_values_videos=None,
+                                       audio_values=None,
+                                       cache_position=None,
+                                       **kwargs):
+         model_inputs = super().prepare_inputs_for_generation(
+             input_ids=input_ids,
+             past_key_values=past_key_values,
+             attention_mask=attention_mask,
+             use_cache=use_cache,
+             pixel_values=pixel_values,
+             pixel_values_videos=pixel_values_videos,
+             audio_values=audio_values,
+             **kwargs
+         )
+         if cache_position[0] != 0:
+             model_inputs["pixel_values"] = model_inputs["pixel_values_videos"] = None
+             model_inputs["audio_values"] = None
+         return model_inputs
+
+     def _expand_inputs_for_generation(self,
+                                       expand_size: int = 1,
+                                       is_encoder_decoder: bool = False,
+                                       input_ids: Optional[torch.LongTensor] = None,
+                                       **model_kwargs,
+                                       ):
+         """Expands input tensors for generation when using beam search or sampling.
+
+         Args:
+             expand_size (int, optional): The size to expand the inputs by. Defaults to 1.
+             is_encoder_decoder (bool, optional): Whether the model is an encoder-decoder model. Defaults to False.
+             input_ids (Optional[torch.LongTensor], optional): The input token IDs. Defaults to None.
+             **model_kwargs: Additional model-specific keyword arguments.
+
+         Returns:
+             Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: The expanded input_ids and model_kwargs.
+         """
+         if input_ids is not None:
+             input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+         # Expand attention mask if present
+         if "attention_mask" in model_kwargs:
+             model_kwargs["attention_mask"] = model_kwargs["attention_mask"].repeat_interleave(expand_size, dim=0)
+
+         # Expand position IDs if present
+         if "position_ids" in model_kwargs:
+             model_kwargs["position_ids"] = model_kwargs["position_ids"].repeat_interleave(expand_size, dim=0)
+
+         # Expand pixel values for images if present
+         if "pixel_values" in model_kwargs and model_kwargs["pixel_values"] is not None:
+             model_kwargs["pixel_values"] = model_kwargs["pixel_values"].repeat_interleave(expand_size, dim=0)
+
+         # Expand pixel values for videos if present
+         if "pixel_values_videos" in model_kwargs and model_kwargs["pixel_values_videos"] is not None:
+             model_kwargs["pixel_values_videos"] = model_kwargs["pixel_values_videos"].repeat_interleave(expand_size, dim=0)
+
+         # Expand audio values if present
+         if "audio_values" in model_kwargs and model_kwargs["audio_values"] is not None:
+             model_kwargs["audio_values"] = model_kwargs["audio_values"].repeat_interleave(expand_size, dim=0)
+
+         # Expand cache position if present
+         if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
+             model_kwargs["cache_position"] = model_kwargs["cache_position"].repeat_interleave(expand_size, dim=0)
+
+         return input_ids, model_kwargs
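A minimal forward-pass sketch, assuming the same placeholder checkpoints as the configuration example above (instantiating the model downloads and holds all three backbones):

# Sketch under assumptions: placeholder checkpoints; a text-only forward pass is shown.
# Multimodal calls would also pass pixel_values / pixel_values_videos / audio_values together
# with the matching pad-token ids placed in input_ids.
import torch
from configuration_qualityv import QualityvConfig
from modeling_qualityv import QualityvForCausalLM

config = QualityvConfig(
    vision_model_name="google/vit-base-patch16-224",
    audio_model_name="openai/whisper-small",
    llm_model_name="Qwen/Qwen2.5-0.5B-Instruct",
)
model = QualityvForCausalLM(config).eval()

input_ids = torch.tensor([[1, 2, 3]])
with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids))
print(out.logits.shape)  # (1, 3, vocab_size)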
processing_qualityv.py ADDED
@@ -0,0 +1,312 @@
+ from typing import Union, Optional, List, Dict, Tuple, Callable
+ from transformers.processing_utils import (ProcessorMixin,
+                                             VideosKwargs,
+                                             AudioKwargs,
+                                             ImagesKwargs,
+                                             TextKwargs,
+                                             ProcessingKwargs,
+                                             Unpack)
+ import numpy as np
+ import decord
+ import torch
+ import PIL
+ from transformers.audio_utils import load_audio
+ from transformers.image_utils import load_image, load_video
+ from transformers import AutoImageProcessor, AutoFeatureExtractor, AutoTokenizer
+
+
+ def load_audio_str(audio_path_or_url: str, sampling_rate: int = 16000) -> np.ndarray:
+     audio = load_audio(audio_path_or_url, sampling_rate=sampling_rate)
+     return audio
+
+
+ def load_video_str(video_path_or_url: str, num_frames: int = 4, fps: int = None) -> List[np.ndarray]:
+     video = load_video(video_path_or_url, num_frames=num_frames, fps=fps,
+                        backend="decord")
+     return video
+
+
+ def load_image_str(image_path_or_url: str) -> List[np.ndarray]:
+     image = load_image(image_path_or_url)
+     return image
+
+
+ ImageInput = Union[
+     # same as transformers.image_utils.ImageInput
+     "PIL.Image.Image", np.ndarray, "torch.Tensor", list["PIL.Image.Image"], list[np.ndarray], list["torch.Tensor"],
+     # image urls, or image_paths
+     str, list[str]
+ ]
+
+
+ VideoInput = Union[
+     # same as transformers.image_utils.VideoInput
+     list["PIL.Image.Image"], "np.ndarray", "torch.Tensor", list["np.ndarray"],
+     list["torch.Tensor"], list[list["PIL.Image.Image"]], list[list["np.ndarray"]],
+     list[list["torch.Tensor"]],
+     # video urls, or video_paths
+     str, list[str], list[list[str]]
+ ]
+
+
+ AudioInput = Union[
+     # same as transformers.audio_utils.AudioInput
+     np.ndarray, "torch.Tensor", List[np.ndarray], Tuple[np.ndarray], List["torch.Tensor"], Tuple["torch.Tensor"],  # noqa: F821
+     # audio urls, or audio_paths
+     str, list[str]
+ ]
+
+
+ class QualityvImageKwargs(ImagesKwargs):
+     tokens_per_image: int = 197
+
+
+ class QualityvVideoKwargs(VideosKwargs):
+     num_frames: Union[int, None] = 4
+     fps: Union[int, None] = None
+     tokens_per_frame: int = 197
+
+
+ class QualityvAudioKwargs(AudioKwargs):
+     sampling_rate: Union[int, None] = 16000
+     tokens_per_audio: int = 1500
+
+
+ class QualityvProcessingKwargs(ProcessingKwargs):
+     images_kwargs: QualityvImageKwargs
+     videos_kwargs: QualityvVideoKwargs
+     audio_kwargs: QualityvAudioKwargs
+     text_kwargs: TextKwargs
+
+
+ class QualityvProcessor(ProcessorMixin):
+
+     attributes = ["image_processor",
+                   "audio_processor",
+                   "tokenizer"]
+     image_processor_class = "AutoImageProcessor"
+     audio_processor_class = "AutoFeatureExtractor"
+     tokenizer_class = "AutoTokenizer"
+
+     chat_template = """{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% set audio_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system
+ You are a helpful assistant.<|im_end|>
+ {% endif %}<|im_start|>{{ message['role'] }}
+ {% if message['content'] is string %}{{ message['content'] }}<|im_end|>
+ {% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'audio' or 'audio' in content %}{% set audio_count.value = audio_count.value + 1 %}{% if add_vision_id %}Audio {{ audio_count.value }}: {% endif %}<|vision_start|><|audio_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>
+ {% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant
+ {% endif %}"""
+
+
+     def __init__(self, tokenizer=None,
+                  image_processor=None,
+                  audio_processor=None,
+                  chat_template=None,
+                  image_token="<|image_pad|>",
+                  video_token="<|video_pad|>",
+                  audio_token="<|audio_pad|>",
+                  label_start_text="<|im_start|>assistant\n",
+                  label_end_text="<|im_end|>\n",
+                  **kwargs):
+         self.image_token = image_token if not hasattr(tokenizer, "image_token") else tokenizer.image_token
+         self.video_token = video_token if not hasattr(tokenizer, "video_token") else tokenizer.video_token
+         self.audio_token = audio_token if not hasattr(tokenizer, "audio_token") else tokenizer.audio_token
+         self.label_start_text = label_start_text
+         self.label_end_text = label_end_text
+         self.image_token_id = (
+             tokenizer.image_token_id
+             if getattr(tokenizer, "image_token_id", None)
+             else tokenizer.convert_tokens_to_ids(self.image_token)
+         )
+         self.video_token_id = (
+             tokenizer.video_token_id
+             if getattr(tokenizer, "video_token_id", None)
+             else tokenizer.convert_tokens_to_ids(self.video_token)
+         )
+         self.audio_token_id = (
+             tokenizer.audio_token_id
+             if getattr(tokenizer, "audio_token_id", None)
+             else tokenizer.convert_tokens_to_ids(self.audio_token)
+         )
+         if chat_template is None:
+             chat_template = self.chat_template
+         super().__init__(image_processor, audio_processor, tokenizer,
+                          chat_template=chat_template)
+
+     def __call__(self,
+                  text: Union[str, List[str], None] = None,
+                  messages: Union[List[Dict], None] = None,
+                  images: Union[ImageInput, None] = None,
+                  videos: Union[VideoInput, None] = None,
+                  audio: Union[AudioInput, None] = None,
+                  do_train: bool = False,
+                  add_generation_prompt: bool = False,
+                  **kwargs: Unpack[QualityvProcessingKwargs]
+                  ):
+         '''
+         input:
+             messages: list of dicts
+             example:
+             [
+                 {"role": "user",
+                  "content": [
+                     {"type": "text", "text": "Hello, how are you?"},
+                     {"type": "image", "image": xxx},
+                     {"type": "video", "video": xxx},
+                  ]
+                 },
+                 ...
+             ]
+         output:
+             input_ids
+             attention_mask
+             pixel_values
+             pixel_values_videos
+             audio_values
+             labels (default None)
+         '''
+         input_ids = []
+         pixel_values = []
+         pixel_values_videos = []
+         audio_values = []
+         labels = None
+
+         if not text and not messages:
+             raise ValueError("At least one of text or messages must be provided.")
+         if messages:
+             text = self.apply_chat_template(messages, add_generation_prompt=add_generation_prompt,
+                                             tokenize=False)
+         if isinstance(text, list):
+             text = text[0]
+         image_list = self.fill_modal_list(self.image_token, "image", messages, images, text)
+         image_list = self.process_str_in_modal_list(image_list, "image", **kwargs.get("images_kwargs", {}))
+         # replace image_token with num_images * num_image_token * image_token
+         if image_list and self.image_token in text:
+             tokens_per_image = kwargs.get("images_kwargs", {}).get("tokens_per_image", 197)
+             text = text.replace(self.image_token, tokens_per_image * self.image_token)
+             pixel_values = self.image_processor(images=image_list, return_tensors="pt")["pixel_values"]
+
+         video_list = self.fill_modal_list(self.video_token, "video", messages, videos, text)
+         video_list = self.process_str_in_modal_list(video_list, "video", **kwargs.get("videos_kwargs", {}))
+         # replace video_token with num_videos * num_video_token * video_token
+         if video_list and self.video_token in text:
+             tokens_per_frame = kwargs.get("videos_kwargs", {}).get("tokens_per_frame", 197)
+             video_frame_list = []
+             for video, video_meta in video_list:
+                 num_frames = video.shape[0]
+                 replace_text = num_frames * tokens_per_frame * self.video_token
+                 text = text.replace(self.video_token, replace_text, 1)
+                 for frame in video:
+                     video_frame_list.append(frame)
+             pixel_values_videos = self.image_processor(images=video_frame_list, return_tensors="pt")["pixel_values"]
+
+         audio_list = self.fill_modal_list(self.audio_token, "audio", messages, audio, text)
+         audio_list = self.process_str_in_modal_list(audio_list, "audio", **kwargs.get("audio_kwargs", {}))
+         # replace audio_token with num_audio_tokens * audio_token
+         if audio_list and self.audio_token in text:
+             audio_kwargs = kwargs.get("audio_kwargs", {})
+             sampling_rate = audio_kwargs.get("sampling_rate", 16000)
+             tokens_per_audio = audio_kwargs.get("tokens_per_audio", 1500)
+             for _ in audio_list:
+                 replace_text = tokens_per_audio * self.audio_token
+                 text = text.replace(self.audio_token, replace_text, 1)
+             audio_values = self.audio_processor(audio_list, return_tensors="pt", sampling_rate=sampling_rate)["input_features"]
+
+         input_ids = self.tokenizer(text).input_ids
+         if do_train:
+             labels = self.get_labels(input_ids)
+             labels = torch.tensor(labels, dtype=torch.long)
+         input_ids = torch.tensor(input_ids, dtype=torch.long)
+         return {
+             "input_ids": input_ids,
+             "pixel_values": pixel_values if len(pixel_values) > 0 else None,
+             "pixel_values_videos": pixel_values_videos if len(pixel_values_videos) > 0 else None,
+             "audio_values": audio_values if len(audio_values) > 0 else None,
+             "labels": labels
+         }
+
+     def fill_modal_list(self, modal_token: str, modal_type: str, messages: List[Dict], modal_values: Union[AudioInput, VideoInput, ImageInput, None], text: str) -> List[Union[AudioInput, VideoInput, ImageInput]]:
+         modal_list = []
+         if modal_token in text:
+             if not modal_values and messages:
+                 for msg in messages:
+                     if msg.get("role") == "user":
+                         for content in msg.get("content", []):
+                             if content.get('type') == modal_type:
+                                 modal_list.append(content.get(modal_type))
+             elif modal_values:
+                 if isinstance(modal_values, str):
+                     modal_list = [modal_values]
+                 else:
+                     modal_list = modal_values
+         return modal_list
+
+     def process_str_in_modal_list(self, modal_list: list, modal_type: str, **modal_kwargs: dict):
+         new_modal_list = []
+         if modal_list:
+             for modal_value in modal_list:
+                 if isinstance(modal_value, str):
+                     new_modal_value = self.load_modal_str(modal_value, modal_type, **modal_kwargs)
+                     new_modal_list.append(new_modal_value)
+                 else:
+                     new_modal_list.append(modal_value)
+         return new_modal_list
+
+     def load_modal_str(self, modal_path_or_url: str, modal_type: str, **modal_kwargs):
+         if modal_type == "image":
+             load_func = load_image_str
+         elif modal_type == "video":
+             load_func = load_video_str
+         elif modal_type == "audio":
+             load_func = load_audio_str
+         else:
+             raise ValueError(f"Invalid modal type: {modal_type}")
+         return load_func(modal_path_or_url, **modal_kwargs)
+
+     def get_labels(self, input_ids: List[int]) -> List[int]:
+         label_start_token_ids = self.tokenizer(self.label_start_text, add_special_tokens=False)["input_ids"]
+         label_end_token_ids = self.tokenizer(self.label_end_text, add_special_tokens=False)["input_ids"]
+
+         labels = [-100] * len(input_ids)
+
+         i = 0
+         while i < len(input_ids):
+             # Look for the assistant's response start marker.
+             if input_ids[i : i + len(label_start_token_ids)] == label_start_token_ids:
+                 # The actual response begins after the start marker.
+                 start_response = i + len(label_start_token_ids)
+                 # Now, search for the end marker.
+                 j = start_response
+                 found_end = False
+                 while j < len(input_ids):
+                     if input_ids[j : j + len(label_end_token_ids)] == label_end_token_ids:
+                         end_response = j + len(label_end_token_ids)  # Mark the end of the response (including the end marker).
+                         found_end = True
+                         break
+                     j += 1
+
+                 if found_end:
+                     # Copy the tokens corresponding to the assistant's response into labels.
+                     labels[start_response:end_response] = input_ids[start_response:end_response]
+                     # Advance i beyond the end marker.
+                     i = end_response
+                     continue  # Continue scanning for the next assistant response.
+                 else:
+                     # If no end marker is found, break out of the loop.
+                     break
+             else:
+                 i += 1
+         pad_token_id = self.tokenizer.pad_token_id
+         if pad_token_id is not None:
+             for i in range(len(labels)):
+                 if labels[i] == pad_token_id:
+                     labels[i] = -100
+         return labels
+
+     def decode(self, *args, **kwargs):
+         return self.tokenizer.decode(*args, **kwargs)
+
+     def batch_decode(self, *args, **kwargs):
+         return self.tokenizer.batch_decode(*args, **kwargs)
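Finally, a sketch of how the processor could be driven end to end, again with placeholder checkpoints; whether a given tokenizer already contains the <|image_pad|> / <|video_pad|> / <|audio_pad|> special tokens, and which transformers version these call paths expect, are assumptions to verify:

# Sketch under assumptions: placeholder checkpoints, a public sample image URL, and a tokenizer
# that already knows the <|image_pad|> special token (add it via add_special_tokens otherwise).
from transformers import AutoFeatureExtractor, AutoImageProcessor, AutoTokenizer
from processing_qualityv import QualityvProcessor

processor = QualityvProcessor(
    tokenizer=AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct"),
    image_processor=AutoImageProcessor.from_pretrained("google/vit-base-patch16-224"),
    audio_processor=AutoFeatureExtractor.from_pretrained("openai/whisper-small"),
)

messages = [
    {"role": "user",
     "content": [
         {"type": "text", "text": "Describe this picture."},
         {"type": "image", "image": "http://images.cocodataset.org/val2017/000000039769.jpg"},
     ]},
]
batch = processor(messages=messages, add_generation_prompt=True)
print(batch["input_ids"].shape)     # 1-D tensor of token ids
print(batch["pixel_values"].shape)  # (1, 3, 224, 224) for a single 224x224 image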