Upload folder
Browse files- __init__.py +0 -0
- __pycache__/__init__.cpython-310.pyc +0 -0
- __pycache__/configuration_qualityv.cpython-310.pyc +0 -0
- __pycache__/modeling_qualityv.cpython-310.pyc +0 -0
- __pycache__/processing_qualityv.cpython-310.pyc +0 -0
- configuration_qualityv.py +78 -0
- modeling_qualityv.py +241 -0
- processing_qualityv.py +312 -0
__init__.py
ADDED
|
File without changes
|
__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (145 Bytes). View file
|
|
|
__pycache__/configuration_qualityv.cpython-310.pyc
ADDED
|
Binary file (2.58 kB). View file
|
|
|
__pycache__/modeling_qualityv.cpython-310.pyc
ADDED
|
Binary file (8.12 kB). View file
|
|
|
__pycache__/processing_qualityv.cpython-310.pyc
ADDED
|
Binary file (9.89 kB). View file
|
|
|
configuration_qualityv.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 2 |
+
from transformers import AutoConfig
|
| 3 |
+
from transformers.activations import ACT2FN
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class QualityLinearAdapterConfig(PretrainedConfig):
    """Configuration for the linear (MLP) multi-modal adapter.

    Describes a small MLP that projects encoder hidden states of width
    ``in_hidden_size`` into the LLM embedding space of width
    ``out_hidden_size``, using ``num_layers`` linear layers with
    ``intermediate_size`` hidden units and the ``act_fn`` activation.
    """

    model_type = "QualityvForCausalLM"
    adapter_type = "linear"

    def __init__(
        self,
        in_hidden_size: int = 1024,
        num_layers: int = 2,
        intermediate_size: int = 2048,
        out_hidden_size: int = 2028,  # NOTE(review): 2028 looks like a typo for 2048 — confirm
        act_fn: str = "gelu",
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        # Store every value verbatim so PretrainedConfig serialization
        # round-trips without loss.
        for attr_name, attr_value in (
            ("in_hidden_size", in_hidden_size),
            ("num_layers", num_layers),
            ("intermediate_size", intermediate_size),
            ("out_hidden_size", out_hidden_size),
            ("act_fn", act_fn),
        ):
            setattr(self, attr_name, attr_value)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class QualityvConfig(PretrainedConfig):
    """Top-level configuration for ``QualityvForCausalLM``.

    Composes an LLM config plus optional vision/audio encoder configs, each
    resolved by name via ``AutoConfig.from_pretrained``. For every present
    encoder, a :class:`QualityLinearAdapterConfig` is built that projects the
    encoder hidden size onto the LLM hidden size.

    NOTE(review): when ``vision_model_name`` or ``audio_model_name`` is given
    without ``llm_model_name``, ``self.llm_config`` is never set and the
    adapter construction below raises AttributeError — confirm callers always
    pass an LLM name. Also note ``vision_adapter_config`` /
    ``audio_adapter_config`` stay undefined when the corresponding encoder
    name is None.
    """
    model_type = "QualityvForCausalLM"
    def __init__(self,
                 vision_model_name: str=None,
                 audio_model_name: str=None,
                 llm_model_name: str=None,
                 image_token_id: int=None,
                 video_token_id: int=None,
                 audio_token_id: int=None,
                 adapter_type: str="linear",
                 num_adapter_layers: int=2,
                 **kwargs,
                 ) -> None:
        super().__init__(**kwargs)
        self.vision_model_name = vision_model_name
        self.audio_model_name = audio_model_name
        self.llm_model_name = llm_model_name
        # Placeholder token ids used by the model to splice modality
        # embeddings into the text embedding sequence.
        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
        self.audio_token_id = audio_token_id
        self.adapter_type = adapter_type
        self.num_adapter_layers = num_adapter_layers
        if llm_model_name is not None:
            # Fetches (and may download) the backbone LLM's config.
            self.llm_config = AutoConfig.from_pretrained(llm_model_name)
            # Flatten every LLM config key onto this config so code that
            # expects e.g. ``config.hidden_size`` keeps working.
            # NOTE(review): this can silently overwrite attributes set above
            # if the LLM config shares a key name — confirm acceptable.
            for key, value in self.llm_config.to_dict().items():
                setattr(self, key, value)
        if vision_model_name is not None:
            self.vision_config = AutoConfig.from_pretrained(vision_model_name)
            # Adapter: encoder hidden -> 2x encoder hidden -> LLM hidden.
            self.vision_adapter_config = QualityLinearAdapterConfig(
                in_hidden_size=self.vision_config.hidden_size,
                intermediate_size=self.vision_config.hidden_size * 2,
                out_hidden_size=self.llm_config.hidden_size,
                num_layers=num_adapter_layers,
            )
        else:
            self.vision_config = None
        if audio_model_name is not None:
            self.audio_config = AutoConfig.from_pretrained(audio_model_name)
            # Same adapter shape policy as the vision branch.
            self.audio_adapter_config = QualityLinearAdapterConfig(
                in_hidden_size=self.audio_config.hidden_size,
                intermediate_size=self.audio_config.hidden_size * 2,
                out_hidden_size=self.llm_config.hidden_size,
                num_layers=num_adapter_layers,
            )
        else:
            self.audio_config = None

    def get_vocab_size(self):
        """Return the backbone LLM's vocabulary size."""
        return self.llm_config.vocab_size

    def get_text_config(self, **kwargs):
        """Delegate to the backbone LLM config (HF text-config protocol)."""
        return self.llm_config.get_text_config(**kwargs)
|
modeling_qualityv.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoModel, AutoModelForCausalLM
|
| 2 |
+
from transformers.activations import ACT2FN
|
| 3 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 4 |
+
from transformers.generation.utils import GenerationMixin
|
| 5 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn
|
| 8 |
+
from torch.nn import RMSNorm
|
| 9 |
+
from typing import List, Optional
|
| 10 |
+
|
| 11 |
+
from .configuration_qualityv import QualityvConfig, QualityLinearAdapterConfig
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class QualityLinearAdapter(nn.Module):
    """MLP adapter projecting encoder features into the LLM embedding space.

    The input is RMS-normalized, then passed through either a single linear
    layer (``num_layers == 1``) or a stack of Linear/activation pairs that
    ends in a projection to ``config.out_hidden_size``.
    """

    def __init__(self, config: QualityLinearAdapterConfig):
        super().__init__()
        self.config = config
        self.norm = RMSNorm(config.in_hidden_size)
        self.act_fn = ACT2FN[config.act_fn]
        if config.num_layers == 1:
            self.linears = nn.Linear(config.in_hidden_size, config.out_hidden_size)
        else:
            # Build (num_layers - 1) hidden Linear+activation pairs followed
            # by the output projection.
            # BUGFIX: the original used config.in_hidden_size as the input
            # width of *every* hidden layer, which causes a shape mismatch at
            # runtime for num_layers >= 3 (layers after the first receive
            # intermediate_size-wide activations). Track the running width
            # instead; num_layers == 2 builds the exact same stack as before.
            model_list = []
            width = config.in_hidden_size
            for _ in range(config.num_layers - 1):
                model_list.append(nn.Linear(width, config.intermediate_size))
                model_list.append(self.act_fn)
                width = config.intermediate_size
            model_list.append(nn.Linear(config.intermediate_size, config.out_hidden_size))
            self.linears = nn.Sequential(*model_list)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize then project ``x``: (..., in_hidden_size) -> (..., out_hidden_size)."""
        return self.linears(self.norm(x))
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class QualityvForCausalLM(PreTrainedModel, GenerationMixin):
    """Multi-modal causal LM wrapping a text LLM plus optional vision/audio
    encoders.

    Modality encoders produce embeddings that are projected by
    :class:`QualityLinearAdapter` and spliced into the text embedding
    sequence at the positions of the modality placeholder tokens
    (``image_token_id`` / ``video_token_id`` / ``audio_token_id``).
    """


    def __init__(self, config: QualityvConfig, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.config = config
        # Sub-models are resolved by name and loaded from the hub/local
        # cache at construction time (may download).
        self.llm_model = AutoModelForCausalLM.from_pretrained(config.llm_model_name)
        if config.vision_config is not None:
            self.vision_model = AutoModel.from_pretrained(config.vision_model_name)
            self.vision_adapter = QualityLinearAdapter(config.vision_adapter_config)
        if config.audio_config is not None:
            self.audio_model = AutoModel.from_pretrained(config.audio_model_name)
            self.audio_adapter = QualityLinearAdapter(config.audio_adapter_config)
            # Fixed 1x2 decoder prompt for the (encoder-decoder) audio model.
            # NOTE(review): only the encoder is used in get_audio_features, so
            # this attribute appears unused here — confirm external usage.
            self.decoder_input_ids = torch.tensor([[1, 1,]]) * self.audio_model.config.decoder_start_token_id
        self.post_init()

    def get_input_embeddings(self):
        """Delegate to the backbone LLM's input embedding table."""
        return self.llm_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Delegate to the backbone LLM."""
        self.llm_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Delegate to the backbone LLM's LM head."""
        return self.llm_model.get_output_embeddings()

    def set_output_embeddings(self, value):
        """Delegate to the backbone LLM."""
        self.llm_model.set_output_embeddings(value)

    def set_decoder(self, decoder):
        """Delegate to the backbone LLM."""
        self.llm_model.set_decoder(decoder)

    def get_decoder(self):
        """Delegate to the backbone LLM."""
        return self.llm_model.get_decoder()

    def get_vision_model(self):
        """Return the vision encoder (only set when vision_config is present)."""
        return self.vision_model

    def get_audio_model(self):
        """Return the audio model (only set when audio_config is present)."""
        return self.audio_model

    def get_video_features(self, pixel_values_videos: torch.Tensor) -> torch.Tensor:
        """Encode video frames with the vision encoder and project to LLM width."""
        video_embeds = self.vision_model(pixel_values_videos).last_hidden_state
        video_embeds = self.vision_adapter(video_embeds)
        return video_embeds

    def get_audio_features(self, audio_values: torch.Tensor) -> torch.Tensor:
        """Encode audio features with the audio model's encoder and project to LLM width."""
        audio_embeds = self.audio_model.encoder(audio_values).last_hidden_state
        audio_embeds = self.audio_adapter(audio_embeds)
        return audio_embeds

    def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode images with the vision encoder and project to LLM width."""
        image_embeds = self.vision_model(pixel_values).last_hidden_state
        image_embeds = self.vision_adapter(image_embeds)
        return image_embeds

    def replace_multi_modal_embeddings(self, multi_modal_embeds: torch.Tensor,
                                       input_embeds: torch.Tensor,
                                       input_ids: torch.LongTensor,
                                       multi_modal_token_id: int,
                                       note: str="multi_modal"):
        """Scatter modality embeddings into ``input_embeds`` at placeholder
        token positions.

        Raises ValueError when the number of placeholder tokens does not
        equal the number of modality embedding vectors.
        """
        # multi_modal_embeds: batch_size * num_frames, hidden_steps, hidden_size
        # input_embeds: batch_size, seq_length, hidden_size
        # input_ids: batch_size, seq_length
        # multi_modal_token_id: int
        # note: str (used only in the error message)
        hidden_size = multi_modal_embeds.shape[-1]
        # Flatten to one embedding vector per placeholder token.
        multi_modal_embeds = multi_modal_embeds.view(-1, hidden_size)
        n_modal_tokens = (input_ids == multi_modal_token_id).sum()
        n_modal_embeds = multi_modal_embeds.shape[0]
        if n_modal_tokens != n_modal_embeds:
            raise ValueError(f"The number of {note} tokens ({n_modal_tokens}) does not match the number of {note} embeddings ({n_modal_embeds}).")
        mask = input_ids == multi_modal_token_id
        mask_unsqueezed = mask.unsqueeze(-1)
        mask_expanded = mask_unsqueezed.expand_as(input_embeds)
        video_mask = mask_expanded.to(input_embeds.device)
        multi_modal_embeds = multi_modal_embeds.to(input_embeds.device, dtype=input_embeds.dtype)
        # masked_scatter fills masked positions in row-major order, which
        # matches the order the flattened embeddings were produced in.
        input_embeds = input_embeds.masked_scatter(video_mask, multi_modal_embeds)
        return input_embeds


    def forward(self,
                input_ids: torch.LongTensor = None,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_values: Optional[List[torch.FloatTensor]] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                labels: Optional[torch.LongTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                pixel_values: Optional[torch.Tensor] = None,
                pixel_values_videos: Optional[torch.FloatTensor] = None,
                audio_values: Optional[torch.FloatTensor] = None,
                cache_position: Optional[torch.LongTensor] = None,
                **kwargs
                ):
        """Embed text, splice in any image/video/audio features, then run the
        backbone LLM (which also computes the loss when ``labels`` is given).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.llm_config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.llm_config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.llm_config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if pixel_values_videos is not None:
            video_features = self.get_video_features(pixel_values_videos)
            inputs_embeds = self.replace_multi_modal_embeddings(video_features, inputs_embeds, input_ids, self.config.video_token_id, note="video")

        if pixel_values is not None:
            image_features = self.get_image_features(pixel_values)
            inputs_embeds = self.replace_multi_modal_embeddings(image_features, inputs_embeds, input_ids, self.config.image_token_id, note="image")

        if audio_values is not None:
            audio_features = self.get_audio_features(audio_values)
            inputs_embeds = self.replace_multi_modal_embeddings(audio_features, inputs_embeds, input_ids, self.config.audio_token_id, note="audio")

        # input_ids is deliberately dropped: the LLM consumes the already
        # spliced inputs_embeds instead.
        outputs = self.llm_model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs
        )

        return outputs



    def prepare_inputs_for_generation(self,
                                      input_ids,
                                      past_key_values=None,
                                      attention_mask=None,
                                      use_cache=None,
                                      pixel_values=None,
                                      pixel_values_videos=None,
                                      audio_values=None,
                                      cache_position=None,
                                      **kwargs):
        """Standard HF generation hook; drops pixel inputs after the prefill
        step so modality features are only spliced once.

        NOTE(review): audio_values is NOT cleared after prefill, unlike
        pixel_values / pixel_values_videos — confirm whether intentional.
        NOTE(review): ``cache_position[0]`` assumes cache_position is never
        None here — confirm generate() always supplies it.
        """
        model_inputs = super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            use_cache=use_cache,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            audio_values=audio_values,
            **kwargs
        )
        if cache_position[0] != 0:
            model_inputs["pixel_values"] = None
            model_inputs["pixel_values_videos"] = None
        return model_inputs

    def _expand_inputs_for_generation(self,
                                      expand_size: int = 1,
                                      is_encoder_decoder: bool = False,
                                      input_ids: Optional[torch.LongTensor] = None,
                                      **model_kwargs,
                                      ):
        """Expands input tensors for generation when using beam search or sampling.

        Args:
            expand_size (int, optional): The size to expand the inputs by. Defaults to 1.
            is_encoder_decoder (bool, optional): Whether the model is an encoder-decoder model. Defaults to False.
            input_ids (Optional[torch.LongTensor], optional): The input token IDs. Defaults to None.
            **model_kwargs: Additional model-specific keyword arguments.

        Returns:
            Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: The expanded input_ids and model_kwargs.
        """
        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)

        # Expand attention mask if present
        if "attention_mask" in model_kwargs:
            model_kwargs["attention_mask"] = model_kwargs["attention_mask"].repeat_interleave(expand_size, dim=0)

        # Expand position IDs if present
        if "position_ids" in model_kwargs:
            model_kwargs["position_ids"] = model_kwargs["position_ids"].repeat_interleave(expand_size, dim=0)

        # Expand pixel values for images if present
        if "pixel_values" in model_kwargs and model_kwargs["pixel_values"] is not None:
            model_kwargs["pixel_values"] = model_kwargs["pixel_values"].repeat_interleave(expand_size, dim=0)

        # Expand pixel values for videos if present
        if "pixel_values_videos" in model_kwargs and model_kwargs["pixel_values_videos"] is not None:
            model_kwargs["pixel_values_videos"] = model_kwargs["pixel_values_videos"].repeat_interleave(expand_size, dim=0)

        # Expand audio values if present
        if "audio_values" in model_kwargs and model_kwargs["audio_values"] is not None:
            model_kwargs["audio_values"] = model_kwargs["audio_values"].repeat_interleave(expand_size, dim=0)

        # Expand cache position if present
        # NOTE(review): cache_position is usually a 1-D per-step position
        # tensor shared across the batch; repeating it along dim 0 here looks
        # unusual — confirm against the generation loop's expectations.
        if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
            model_kwargs["cache_position"] = model_kwargs["cache_position"].repeat_interleave(expand_size, dim=0)

        return input_ids, model_kwargs
|
processing_qualityv.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union, Optional, List, Dict, Tuple, Callable
|
| 2 |
+
from transformers.processing_utils import (ProcessorMixin,
|
| 3 |
+
VideosKwargs,
|
| 4 |
+
AudioKwargs,
|
| 5 |
+
ImagesKwargs,
|
| 6 |
+
TextKwargs,
|
| 7 |
+
ProcessingKwargs,
|
| 8 |
+
Unpack)
|
| 9 |
+
import numpy as np
|
| 10 |
+
import decord
|
| 11 |
+
import torch
|
| 12 |
+
import PIL
|
| 13 |
+
from transformers.audio_utils import load_audio
|
| 14 |
+
from transformers.image_utils import load_image, load_video
|
| 15 |
+
from transformers import AutoImageProcessor, AutoFeatureExtractor, AutoTokenizer
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def load_audio_str(audio_path_or_url: str, sampling_rate: int = 16000, **kwargs) -> np.ndarray:
    """Load an audio file or URL and resample it to ``sampling_rate``.

    Extra keyword arguments are accepted and ignored: the processor forwards
    its whole ``audio_kwargs`` dict here (which also carries
    ``tokens_per_audio``), and the original signature raised TypeError on the
    unknown key.
    """
    return load_audio(audio_path_or_url, sampling_rate=sampling_rate)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def load_video_str(video_path_or_url: str, num_frames: int = 4, fps: int = None, **kwargs) -> List[np.ndarray]:
    """Decode a video file or URL into frames using the decord backend.

    Extra keyword arguments are accepted and ignored: the processor forwards
    its whole ``videos_kwargs`` dict here (which also carries
    ``tokens_per_frame``), and the original signature raised TypeError on the
    unknown key.
    """
    return load_video(video_path_or_url, num_frames=num_frames, fps=fps,
                      backend="decord")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def load_image_str(image_path_or_url: str, **kwargs) -> List[np.ndarray]:
    """Load a single image from a path or URL.

    Extra keyword arguments are accepted and ignored: the processor forwards
    its whole ``images_kwargs`` dict here (which also carries
    ``tokens_per_image``), and the original signature raised TypeError on the
    unknown key.
    """
    return load_image(image_path_or_url)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# Input type aliases for the processor. Each extends the corresponding
# transformers alias with plain strings (paths/URLs) that are loaded lazily
# by ``QualityvProcessor.load_modal_str``.

ImageInput = Union[
    # same as transformers.image_utils.ImageInput
    "PIL.Image.Image", np.ndarray, "torch.Tensor", list["PIL.Image.Image"], list[np.ndarray], list["torch.Tensor"],
    # image urls, or image_paths
    str, list[str]
]


VideoInput = Union[
    # same as transformers.image_utils.VideoInput
    list["PIL.Image.Image"], "np.ndarray", "torch.Tensor", list["np.ndarray"],
    list["torch.Tensor"], list[list["PIL.Image.Image"]], list[list["np.ndarray"]],
    list[list["torch.Tensor"]],
    # video urls, or video_paths
    str, list[str], list[list[str]]
]


AudioInput = Union[
    # same as transformers.audio_utils.AudioInput
    np.ndarray, "torch.Tensor", List[np.ndarray], Tuple[np.ndarray], List["torch.Tensor"], Tuple["torch.Tensor"],  # noqa: F821
    # audio urls, or audio_paths
    str, list[str]
]
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class QualityvImageKwargs(ImagesKwargs):
    # Number of placeholder tokens each image expands to in the prompt
    # (197 would match a ViT with 196 patches + CLS — TODO confirm encoder).
    tokens_per_image: int = 197
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class QualityvVideoKwargs(VideosKwargs):
    # Number of frames to sample per video (used when fps is None).
    num_frames: Union[int, None] = 4
    # Alternative frame-sampling rate; forwarded to the video loader.
    fps: Union[int, None] = None
    # Placeholder tokens each sampled frame expands to in the prompt.
    tokens_per_frame: int = 197
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class QualityvAudioKwargs(AudioKwargs):
    # Target sampling rate audio is resampled to before feature extraction.
    sampling_rate: Union[int, None] = 16000
    # Placeholder tokens each audio clip expands to in the prompt
    # (1500 matches a Whisper-style 30 s encoder output — TODO confirm).
    tokens_per_audio: int = 1500
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class QualityvProcessingKwargs(ProcessingKwargs):
    """Typed kwargs bundle accepted by ``QualityvProcessor.__call__``."""
    images_kwargs: QualityvImageKwargs
    videos_kwargs: QualityvVideoKwargs
    audio_kwargs: QualityvAudioKwargs
    text_kwargs: TextKwargs
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class QualityvProcessor(ProcessorMixin):
    """Processor combining an image processor, an audio feature extractor and
    a tokenizer for Qualityv.

    Renders chat ``messages`` with a Qwen-style template, expands each
    modality placeholder token to the number of embedding slots the model
    will produce for that modality, and tokenizes the result.
    """

    attributes = ["image_processor",
                  "audio_processor",
                  "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    audio_processor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    # Default Jinja chat template (Qwen-style markers); one placeholder token
    # per image/video/audio item — expanded later in __call__.
    chat_template = """{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% set audio_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system
You are a helpful assistant.<|im_end|>
{% endif %}<|im_start|>{{ message['role'] }}
{% if message['content'] is string %}{{ message['content'] }}<|im_end|>
{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'audio' or 'audio' in content %}{% set audio_count.value = audio_count.value + 1 %}{% if add_vision_id %}Audio {{ audio_count.value }}: {% endif %}<|vision_start|><|audio_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>
{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant
{% endif %}"""


    def __init__(self, tokenizer=None,
                 image_processor=None,
                 audio_processor=None,
                 chat_template=None,
                 image_token="<|image_pad|>",
                 video_token="<|video_pad|>",
                 audio_token="<|audio_pad|>",
                 label_start_text="<|im_start|>assistant\n",
                 label_end_text="<|im_end|>\n",
                 **kwargs):
        # Prefer tokens defined on the tokenizer; fall back to the defaults.
        self.image_token = image_token if not hasattr(tokenizer, "image_token") else tokenizer.image_token
        self.video_token = video_token if not hasattr(tokenizer, "video_token") else tokenizer.video_token
        self.audio_token = audio_token if not hasattr(tokenizer, "audio_token") else tokenizer.audio_token
        # Markers delimiting assistant turns, used by get_labels to mask loss.
        self.label_start_text = label_start_text
        self.label_end_text = label_end_text
        self.image_token_id = (
            tokenizer.image_token_id
            if getattr(tokenizer, "image_token_id", None)
            else tokenizer.convert_tokens_to_ids(self.image_token)
        )
        self.video_token_id = (
            tokenizer.video_token_id
            if getattr(tokenizer, "video_token_id", None)
            else tokenizer.convert_tokens_to_ids(self.video_token)
        )
        self.audio_token_id = (
            tokenizer.audio_token_id
            if getattr(tokenizer, "audio_token_id", None)
            else tokenizer.convert_tokens_to_ids(self.audio_token)
        )
        if chat_template is None:
            chat_template = self.chat_template
        super().__init__(image_processor, audio_processor, tokenizer,
                         chat_template=chat_template)

    def __call__(self,
                 text: Union[str, List[str], None] = None,
                 messages: Union[List[Dict], None] = None,
                 images: Union[ImageInput, None] = None,
                 videos: Union[VideoInput, None] = None,
                 audio: Union[AudioInput, None] = None,
                 do_train: bool = False,
                 add_generation_prompt: bool = False,
                 **kwargs: Unpack[QualityvProcessingKwargs]
                 ):
        '''
        input
            messages: list of dicts
                example:
                    [
                        {"role": "user"
                         "content": [
                             {"type": "text", "text": "Hello, how are you?"},
                             {"type": "image", "image":xxx)},
                             {"type": "video", "video": xxx},
                         ]
                        },
                        ...
                    ]
        output:
            input_ids
            attention_mask
            pixel_values,
            pixel_values_videos
            audio_values
            labels, default None,
        '''
        input_ids = []
        pixel_values = []
        pixel_values_videos = []
        audio_values = []
        labels = None

        if not text and not messages:
            raise ValueError("At least one of text or messages must be provided.")
        if messages:
            # Render the conversation to a single prompt string; modality
            # items become single placeholder tokens at this point.
            text = self.apply_chat_template(messages, add_generation_prompt=add_generation_prompt,
                                            tokenize=False)
            if isinstance(text, list):
                text = text[0]
        image_list = self.fill_modal_list(self.image_token, "image", messages, images, text)
        image_list = self.process_str_in_modal_list(image_list, "image", **kwargs.get("images_kwargs", {}))
        # replace image_token with num_images * num_image_token * image_token
        if image_list and self.image_token in text:
            tokens_per_image = kwargs.get("images_kwargs", {}).get("tokens_per_image", 197)
            # str.replace without a count expands every image placeholder by
            # the same factor in one pass.
            text = text.replace(self.image_token, tokens_per_image * self.image_token)
            pixel_values = self.image_processor(images=image_list, return_tensors="pt")["pixel_values"]

        video_list = self.fill_modal_list(self.video_token, "video", messages, videos, text)
        video_list = self.process_str_in_modal_list(video_list, "video", **kwargs.get("videos_kwargs", {}))
        # replace video_token with num_videos * num_video_token * video_token
        if video_list and self.video_token in text:
            tokens_per_frame = kwargs.get("videos_kwargs", {}).get("tokens_per_frame", 197)
            video_frame_list = []
            # NOTE(review): unpacking assumes each entry is a
            # (frames, metadata) pair as returned by transformers' load_video;
            # raw arrays passed directly as ``videos`` would not unpack —
            # confirm expected input form.
            for video, video_meta in video_list:
                num_frames = video.shape[0]
                replace_text = num_frames * tokens_per_frame * self.video_token
                # count=1: expand placeholders one video at a time, in order.
                text = text.replace(self.video_token, replace_text, 1)
                for frame in video:
                    video_frame_list.append(frame)
            pixel_values_videos = self.image_processor(images=video_frame_list, return_tensors="pt")["pixel_values"]

        audio_list = self.fill_modal_list(self.audio_token, "audio", messages, audio, text)
        audio_list = self.process_str_in_modal_list(audio_list, "audio", **kwargs.get("audio_kwargs", {}))
        # replace audio_token with num_audio_tokens * audio_token
        if audio_list and self.audio_token in text:
            audio_kwargs = kwargs.get("audio_kwargs", {})
            sampling_rate = audio_kwargs.get("sampling_rate", 16000)
            tokens_per_audio = audio_kwargs.get("tokens_per_audio", 1500)
            # NOTE(review): loop variable shadows the ``audio`` parameter;
            # harmless here since the parameter is no longer read afterwards.
            for audio in audio_list:
                replace_text = tokens_per_audio * self.audio_token
                text = text.replace(self.audio_token, replace_text, 1)
            audio_values = self.audio_processor(audio_list, return_tensors="pt", sampling_rate=sampling_rate)["input_features"]

        input_ids = self.tokenizer(text).input_ids
        if do_train:
            # Labels mask everything except assistant responses (see get_labels).
            labels = self.get_labels(input_ids)
            labels = torch.tensor(labels, dtype=torch.long)
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        return {
            "input_ids": input_ids,
            "pixel_values": pixel_values if len(pixel_values) > 0 else None,
            "pixel_values_videos": pixel_values_videos if len(pixel_values_videos) > 0 else None,
            "audio_values": audio_values if len(audio_values) > 0 else None,
            "labels": labels
        }

    def fill_modal_list(self, modal_token: str, model_type: str, messages: List[Dict], modal_values: Union[AudioInput, VideoInput, ImageInput, None], text: str) -> List[Union[AudioInput, VideoInput, ImageInput]]:
        """Collect the modality items for ``modal_token``.

        Items come from explicit ``modal_values`` when given, otherwise they
        are pulled out of user messages. Returns an empty list when the token
        does not appear in ``text``.
        """
        modal_list = []
        if modal_token in text:
            if not modal_values and messages:
                # Gather items in reading order from user turns only.
                for msg in messages:
                    if msg.get("role") == "user":
                        for content in msg.get("content", []):
                            if content.get('type') == model_type:
                                modal_list.append(content.get(model_type))
            elif modal_values:
                if isinstance(modal_values, str):
                    # A single path/URL counts as one item.
                    modal_list = [modal_values]
                else:
                    modal_list = modal_values
        return modal_list

    def process_str_in_modal_list(self, modal_list: list, modal_type: str, **modal_kwargs: dict):
        """Replace any string entries (paths/URLs) with loaded media objects;
        non-string entries pass through unchanged."""
        new_modal_list = []
        if modal_list:
            for modal_value in modal_list:
                if isinstance(modal_value, str):
                    new_modal_value = self.load_modal_str(modal_value, modal_type, **modal_kwargs)
                    new_modal_list.append(new_modal_value)
                else:
                    new_modal_list.append(modal_value)
        return new_modal_list

    def load_modal_str(self, model_path_or_url: str, modal_type: str, **modal_kwargs):
        """Dispatch a path/URL to the loader for ``modal_type``.

        NOTE(review): ``modal_kwargs`` is the raw *_kwargs dict, which also
        carries tokens_per_* keys the loaders do not declare — confirm the
        loaders tolerate extra keywords.
        """
        if modal_type == "image":
            load_func = load_image_str
        elif modal_type == "video":
            load_func = load_video_str
        elif modal_type == "audio":
            load_func = load_audio_str
        else:
            raise ValueError(f"Invalid modal type: {modal_type}")
        return load_func(model_path_or_url, **modal_kwargs)

    def get_labels(self, input_ids: List[int]) -> List[int]:
        """Build training labels: -100 everywhere except tokens inside
        assistant responses (between label_start_text and label_end_text,
        end marker included); pad tokens are also masked to -100.
        """
        label_start_token_ids = self.tokenizer(self.label_start_text, add_special_tokens=False)["input_ids"]
        label_end_token_ids = self.tokenizer(self.label_end_text, add_special_tokens=False)["input_ids"]

        labels = [-100] * len(input_ids)

        i = 0
        while i < len(input_ids):
            # Look for the assistant's response start marker.
            if input_ids[i : i + len(label_start_token_ids)] == label_start_token_ids:
                # The actual response begins after the start marker.
                start_response = i + len(label_start_token_ids)
                # Now, search for the end marker.
                j = start_response
                found_end = False
                while j < len(input_ids):
                    if input_ids[j : j + len(label_end_token_ids)] == label_end_token_ids:
                        end_response = j + len(label_end_token_ids)  # Mark the end of the response (excluding the end marker)
                        found_end = True
                        break
                    j += 1

                if found_end:
                    # Copy the tokens corresponding to the assistant's response into labels.
                    labels[start_response:end_response] = input_ids[start_response:end_response]
                    # Advance i beyond the end marker.
                    i = end_response
                    continue  # Continue scanning for the next assistant response.
                else:
                    # If no end marker is found, break out of the loop.
                    break
            else:
                i += 1
        # Mask padding so it never contributes to the loss.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is not None:
            for i in range(len(labels)):
                if labels[i] == pad_token_id:
                    labels[i] = -100
        return labels

    def decode(self, *args, **kwargs):
        """Delegate to the tokenizer."""
        return self.tokenizer.decode(*args, **kwargs)

    def batch_decode(self, *args, **kwargs):
        """Delegate to the tokenizer."""
        return self.tokenizer.batch_decode(*args, **kwargs)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
|