| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """Inference-only IBM Granite speech model.""" |
| | import math |
| | from collections.abc import Iterable, Mapping |
| | from typing import Optional, TypedDict, Union |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from torch import nn |
| | from transformers import BatchFeature, PretrainedConfig |
| |
|
| | from vllm.config import CacheConfig, VllmConfig |
| | from vllm.model_executor.layers.linear import (ColumnParallelLinear, |
| | RowParallelLinear) |
| | from vllm.model_executor.layers.quantization import QuantizationConfig |
| | from vllm.model_executor.layers.sampler import get_sampler |
| | from vllm.model_executor.models.module_mapping import MultiModelKeys |
| | from vllm.model_executor.sampling_metadata import SamplingMetadata |
| | from vllm.multimodal import MULTIMODAL_REGISTRY |
| | from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, |
| | MultiModalKwargs) |
| | from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, |
| | MultiModalDataParser) |
| | from vllm.multimodal.processing import (BaseMultiModalProcessor, |
| | BaseProcessingInfo, PromptReplacement, |
| | PromptUpdate) |
| | from vllm.multimodal.profiling import BaseDummyInputsBuilder |
| | from vllm.sequence import IntermediateTensors |
| |
|
| | from .blip2 import Blip2QFormerModel |
| | from .interfaces import (MultiModalEmbeddings, SupportsLoRA, |
| | SupportsMultiModal, SupportsPP) |
| | from .utils import (AutoWeightsLoader, embed_multimodal, |
| | init_vllm_registered_model, maybe_prefix) |
| |
|
| |
|
| | |
class GraniteSpeechAudioInputs(TypedDict):
    """Validated audio inputs handed to the Granite speech encoder."""

    # Mel input features, shape `(bsz, num_features, 160)`.
    input_features: torch.Tensor

    # Boolean mask over the projected audio embeddings marking the valid
    # (non-padded) positions of each batch entry; when rebuilt from
    # `audio_embed_sizes` its shape is `(bsz, max(audio_embed_sizes))`.
    input_features_mask: torch.Tensor

    # Number of audio tokens for each of the `bsz` batch entries.
    audio_embed_sizes: list[int]
| |
|
| |
|
class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo):
    """Modality limits and worst-case audio sizes for Granite speech."""

    # Fixed profiling bounds. NOTE(review): these look like generous
    # ceilings rather than hard model limits -- confirm against the HF
    # feature extractor before relying on them elsewhere.
    _MAX_AUDIO_TOKENS = 5001
    _MAX_AUDIO_LEN = 8000000

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # At most one audio clip is accepted per prompt.
        return {"audio": 1}

    def get_max_audio_tokens(self):
        return self._MAX_AUDIO_TOKENS

    def get_max_audio_len(self):
        # Also used as the length of the dummy audio built for profiling.
        return self._MAX_AUDIO_LEN
| |
|
| |
|
| | |
class GraniteSpeechMultiModalProcessor(
        BaseMultiModalProcessor[GraniteSpeechMultiModalProcessingInfo]):
    """Multimodal processor for Granite speech.

    Runs the HF processor, records how many audio tokens each prompt
    contains, and expands each audio placeholder into one token per
    projected audio feature.
    """

    def _get_data_parser(self) -> MultiModalDataParser:
        # Resample incoming audio to the rate the Mel extractor expects.
        melspec_kwargs = self.info.get_hf_processor(
        ).audio_processor.melspec_kwargs
        return MultiModalDataParser(target_sr=melspec_kwargs["sample_rate"])

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Both fields hold one entry per audio item.
        return {
            field_name: MultiModalFieldConfig.batched("audio")
            for field_name in ("input_features", "audio_embed_sizes")
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptUpdate]:
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        audio_processor = hf_processor.audio_processor
        token_to_id = tokenizer.get_vocab()

        # Prefer the processor's own placeholder string when it defines one.
        placeholder = getattr(hf_processor, "audio_token", "<|audio|>")
        placeholder_id = token_to_id[placeholder]

        def expand_placeholder(item_idx: int):
            # One placeholder token per projected feature of this clip.
            clip = mm_items.get_items("audio",
                                      AudioProcessorItems).get(item_idx)
            num_tokens = audio_processor._get_num_audio_features(
                [clip.shape[-1]])[0]
            return [placeholder_id] * num_tokens

        return [
            PromptReplacement(
                modality="audio",
                target=[placeholder_id],
                replacement=expand_placeholder,
            )
        ]

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        mm_data = dict(mm_data)
        # Audio arriving under "audios" is forwarded under the "audio" key.
        audios = mm_data.pop("audios", [])
        if audios:
            mm_data["audio"] = audios

        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        if "audio" not in mm_data:
            return outputs

        # Count the audio tokens of each prompt; the model later uses these
        # counts to split the padded audio embeddings back per item.
        audio_token_index = self.info.get_hf_config().audio_token_index
        outputs["audio_embed_sizes"] = [
            torch.sum(ids == audio_token_index).item()
            for ids in outputs["input_ids"]
        ]
        return outputs
| |
|
| |
|
class GraniteSpeechDummyInputsBuilder(
        BaseDummyInputsBuilder[GraniteSpeechMultiModalProcessingInfo]):
    """Builds worst-case dummy audio prompts for memory profiling."""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        audio_count = mm_counts.get("audio", 0)
        # Every dummy clip uses the maximum supported audio length.
        dummy_audios = self._get_dummy_audios(
            length=self.info.get_max_audio_len(),
            num_audios=audio_count,
        )
        return {"audio": dummy_audios}

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        audio_count = mm_counts.get("audio", 0)
        placeholder = getattr(self.info.get_hf_processor(), "audio_token",
                              "<|audio|>")
        # One placeholder per dummy audio, with no surrounding text.
        return placeholder * audio_count
| |
|
| |
|
| | |
class GraniteSpeechEncoderProjector(nn.Module):
    """Q-Former projector from encoder outputs to text embeddings.

    The encoder sequence is split into windows of ``window_size`` frames;
    each window is summarized by ``window_size // downsample_rate`` learned
    queries through a Q-Former, and the query outputs are linearly projected
    to the text model's hidden size.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        projector_config = config.projector_config
        self.hidden_size = projector_config.hidden_size
        self.downsample_rate = config.downsample_rate
        self.window_size = config.window_size
        self.num_queries = config.window_size // config.downsample_rate

        # Learned query embeddings, shared by every window.
        self.query = nn.Parameter(
            torch.zeros(1, self.num_queries, projector_config.hidden_size))

        # The Q-Former is instantiated directly rather than generically.
        self.qformer = Blip2QFormerModel(
            projector_config,
            quant_config=quant_config,
            cache_config=cache_config,
            prefix=f"{prefix}.qformer",
        )
        self.linear = nn.Linear(projector_config.hidden_size,
                                config.text_config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project ``(bsz, seq_len, dim)`` encoder states to
        ``(bsz, padded_len // downsample_rate, text_hidden_size)``."""
        batch_size, seq_len, dim = hidden_states.size()
        num_windows = math.ceil(seq_len / self.window_size)
        padded_len = num_windows * self.window_size

        # Zero-pad the time axis on the right to a whole number of windows,
        # then fold the windows into the batch dimension.
        windows = nn.functional.pad(hidden_states,
                                    (0, 0, 0, padded_len - seq_len),
                                    "constant", 0).view(
                                        batch_size * num_windows,
                                        self.window_size, dim)

        query_out = self.qformer(
            query_embeds=self.query.data,
            encoder_hidden_states=windows,
        )

        tokens_per_item = padded_len // self.downsample_rate
        return self.linear(query_out.view(batch_size, tokens_per_item, -1))
| |
|
| |
|
| | |
| | |
| | |
class GraniteSpeechConformerFeedForward(nn.Module):
    """Pre-norm feed-forward module for conformer encoder blocks.

    Computes LayerNorm -> up projection -> SiLU -> down projection, widening
    to ``hidden_dim * feedforward_mult`` in between.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        hidden_dim = config.hidden_dim
        intermediate_dim = hidden_dim * config.feedforward_mult

        self.pre_norm = nn.LayerNorm(hidden_dim)
        self.up_proj = ColumnParallelLinear(
            input_size=hidden_dim,
            output_size=intermediate_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.up_proj",
        )
        self.silu = nn.SiLU()
        self.down_proj = RowParallelLinear(
            input_size=intermediate_dim,
            output_size=hidden_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The parallel linear layers return (output, bias); only the output
        # is kept.
        widened, _ = self.up_proj(self.pre_norm(hidden_states))
        narrowed, _ = self.down_proj(self.silu(widened))
        return narrowed
| |
|
| |
|
| | class GraniteSpeechConformerAttention(nn.Module): |
| | """Attention for conformer blocks using Shaw's relative positional |
| | embeddings. See the following [paper](https://arxiv.org/pdf/1803.02155) |
| | for more details. |
| | """ |
| |
|
| | def __init__(self, config: PretrainedConfig, prefix: str = ""): |
| | super().__init__() |
| |
|
| | inner_dim = config.dim_head * config.num_heads |
| | self.max_pos_emb = config.max_pos_emb |
| | self.context_size = config.context_size |
| | self.num_heads = config.num_heads |
| | self.dim_head = config.dim_head |
| | self.scale = self.dim_head**-0.5 |
| | self.pre_norm = nn.LayerNorm(config.hidden_dim) |
| | self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False) |
| | self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False) |
| | self.to_out = nn.Linear(inner_dim, config.hidden_dim) |
| | self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1, |
| | self.dim_head) |
| |
|
| | if self.context_size <= 0 or self.context_size > self.max_pos_emb: |
| | raise ValueError( |
| | "Context size is either less than 0 or exceeds the max_pos_emb" |
| | ) |
| |
|
| | def forward(self, hidden_states: torch.Tensor, |
| | attention_dists: torch.Tensor) -> torch.Tensor: |
| | hidden_states = self.pre_norm(hidden_states) |
| | bsz, num_features, _ = hidden_states.shape |
| |
|
| | num_blocks = math.ceil(num_features / self.context_size) |
| | remainder = num_features % self.context_size |
| | if remainder > 0: |
| | |
| | hidden_states = torch.nn.functional.pad( |
| | hidden_states, (0, 0, 0, self.context_size - remainder)) |
| |
|
| | |
| | |
| | query_states = self.to_q(hidden_states) |
| | key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1) |
| |
|
| | query_states = query_states.reshape(bsz, num_blocks, self.context_size, |
| | self.num_heads, |
| | -1).transpose(2, 3) |
| | key_states = key_states.reshape(bsz, num_blocks, self.context_size, |
| | self.num_heads, -1).transpose(2, 3) |
| | value_states = value_states.reshape(bsz, num_blocks, self.context_size, |
| | self.num_heads, |
| | -1).transpose(2, 3) |
| |
|
| | |
| | dist = attention_dists.to(hidden_states.device) |
| | rel_pos_emb = self.rel_pos_emb(dist) |
| | rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + |
| | list(rel_pos_emb.shape)) |
| | pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, |
| | dim=-1) * self.scale |
| |
|
| | if remainder > 0: |
| | |
| | mask = torch.ones(self.context_size, |
| | self.context_size, |
| | dtype=bool, |
| | device=hidden_states.device) |
| | mask[:remainder, :remainder] = 0 |
| | mask_value = -torch.finfo(pos_attn.dtype).max |
| | pos_attn[:, -1, :].masked_fill_(mask, mask_value) |
| |
|
| | with torch.nn.attention.sdpa_kernel( |
| | torch.nn.attention.SDPBackend.MATH): |
| | out = F.scaled_dot_product_attention(query_states, |
| | key_states, |
| | value_states, |
| | attn_mask=pos_attn, |
| | scale=self.scale) |
| | out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1) |
| | return self.to_out(out[:, :num_features, :]) |
| |
|
| |
|
class GraniteSpeechConformerDepthWiseConv1d(nn.Module):
    """Depthwise 1D convolution with "same"-length padding.

    Each input channel is convolved with its own filters
    (``groups=chan_in``, no bias). The input is padded so that the output
    has the same length as the input for both odd and even kernel sizes.
    """

    def __init__(self,
                 chan_in: int,
                 chan_out: int,
                 kernel_size: int,
                 prefix: str = ""):
        super().__init__()
        # Total padding is kernel_size - 1: symmetric for odd kernels; for
        # even kernels the left side receives one extra element.
        pad = kernel_size // 2
        pad_offset = (kernel_size + 1) % 2
        self.padding = (pad, pad - pad_offset)

        self.conv = nn.Conv1d(chan_in,
                              chan_out,
                              kernel_size,
                              groups=chan_in,
                              bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Pad the last (length) axis, then apply the depthwise conv to
        channels-first input of shape ``(bsz, chan_in, length)``."""
        hidden_states = F.pad(hidden_states, self.padding)
        return self.conv(hidden_states)
| |
|
| |
|
class GraniteSpeechConformerConvModule(nn.Module):
    """Conformer convolution module.

    LayerNorm, a pointwise up-convolution with GLU gating, a depthwise
    convolution, BatchNorm + SiLU, and a pointwise down-convolution back to
    ``hidden_dim``.
    """

    def __init__(self, config: PretrainedConfig, prefix: str = ""):
        super().__init__()
        hidden_dim = config.hidden_dim
        expanded_dim = hidden_dim * config.conv_expansion_factor

        self.norm = nn.LayerNorm(hidden_dim)
        # Produce twice the width so the channel-wise GLU halves it back.
        self.up_conv = nn.Conv1d(hidden_dim, expanded_dim * 2, 1)
        self.glu = nn.GLU(dim=1)
        self.depth_conv = GraniteSpeechConformerDepthWiseConv1d(
            expanded_dim,
            expanded_dim,
            kernel_size=config.conv_kernel_size,
            prefix=f"{prefix}.depth_conv",
        )
        self.silu = nn.SiLU()
        self.batch_norm = nn.BatchNorm1d(expanded_dim)
        self.down_conv = nn.Conv1d(expanded_dim, hidden_dim, 1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The convolutions are channels-first, so move features to dim 1 for
        # the conv stack and back to the last dim afterwards.
        x = self.norm(hidden_states).permute(0, 2, 1)
        x = self.glu(self.up_conv(x))
        x = self.silu(self.batch_norm(self.depth_conv(x)))
        return self.down_conv(x).permute(0, 2, 1)
| |
|
| |
|
class GraniteSpeechConformerBlock(nn.Module):
    """Conformer block: half-step feed-forward, relative-position
    attention, convolution module, a second half-step feed-forward, each
    with a residual connection, followed by a final LayerNorm."""

    def __init__(self, config: PretrainedConfig, prefix: str = ""):
        super().__init__()
        self.ff1 = GraniteSpeechConformerFeedForward(
            config, prefix=f"{prefix}.ff1")
        self.attn = GraniteSpeechConformerAttention(
            config, prefix=f"{prefix}.attn")
        self.conv = GraniteSpeechConformerConvModule(
            config, prefix=f"{prefix}.conv")
        self.ff2 = GraniteSpeechConformerFeedForward(
            config, prefix=f"{prefix}.ff2")
        self.post_norm = nn.LayerNorm(config.hidden_dim)

    def forward(self, hidden_states: torch.Tensor,
                attention_dists: torch.Tensor) -> torch.Tensor:
        x = hidden_states
        # Feed-forward residuals are scaled by 0.5.
        x = x + 0.5 * self.ff1(x)
        x = x + self.attn(x, attention_dists=attention_dists)
        x = x + self.conv(x)
        x = x + 0.5 * self.ff2(x)
        return self.post_norm(x)
| |
|
| |
|
class GraniteSpeechCTCEncoder(nn.Module):
    """CTC encoder: an input projection followed by a stack of conformer
    blocks. After block ``num_layers // 2``, the ``out`` layer's outputs are
    softmaxed, mapped back to ``hidden_dim`` by ``out_mid``, and added to
    the hidden states."""

    def __init__(self,
                 config: PretrainedConfig,
                 prefix: str,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config

        # Pairwise relative distances within one attention block, clamped
        # to +/- context_size and shifted by max_pos_emb so they are valid
        # non-negative indices into the relative-position embedding table.
        positions = torch.arange(config.context_size)
        relative = positions[:, None] - positions[None, :]
        self.attention_dists = relative.clamp(
            -config.context_size, config.context_size) + config.max_pos_emb

        self.input_linear = nn.Linear(config.input_dim,
                                      config.hidden_dim,
                                      bias=True)
        self.layers = nn.ModuleList(
            GraniteSpeechConformerBlock(config,
                                        prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(config.num_layers))

        self.out = ColumnParallelLinear(
            input_size=config.hidden_dim,
            output_size=config.output_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out",
        )
        self.out_mid = RowParallelLinear(
            input_size=config.output_dim,
            output_size=config.hidden_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out_mid",
        )
        self.softmax = nn.Softmax(dim=-1)
        self.num_layers = config.num_layers

    def forward(self, hidden_states: torch.Tensor):
        hidden_states = self.input_linear(hidden_states)
        mid_layer = self.num_layers // 2
        for layer_num, layer in enumerate(self.layers, start=1):
            hidden_states = layer(hidden_states,
                                  attention_dists=self.attention_dists)
            if layer_num != mid_layer:
                continue
            # Intermediate head: softmax over `out`, projected back and
            # added in place to the running hidden states.
            mid_logits, _ = self.out(hidden_states.clone())
            mid_update, _ = self.out_mid(self.softmax(mid_logits))
            hidden_states += mid_update
        return hidden_states
| |
|
| |
|
@MULTIMODAL_REGISTRY.register_processor(
    GraniteSpeechMultiModalProcessor,
    info=GraniteSpeechMultiModalProcessingInfo,
    dummy_inputs=GraniteSpeechDummyInputsBuilder)
class GraniteSpeechForConditionalGeneration(
        nn.Module,
        SupportsMultiModal,
        SupportsPP,
        SupportsLoRA,
):
    """Granite speech model for conditional generation.

    Mel audio features are encoded by a conformer CTC encoder, projected
    into the language model's embedding space by a Q-Former projector, and
    merged into the text embeddings at the audio placeholder positions.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        cache_config = vllm_config.cache_config

        self.config = config
        self.quant_config = quant_config
        self.cache_config = cache_config
        self.sampler = get_sampler()

        # Text backbone, built from the text sub-config.
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        # Conformer CTC encoder over the Mel input features.
        self.encoder = GraniteSpeechCTCEncoder(
            config=config.encoder_config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
        )

        # Q-Former projector mapping encoder outputs to the text hidden size.
        self.projector = GraniteSpeechEncoderProjector(
            config=config,
            quant_config=quant_config,
            cache_config=cache_config,
            prefix=f"{prefix}.projector",
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_audio_input(
        self,
        **kwargs: object,
    ) -> Optional[GraniteSpeechAudioInputs]:
        """Extract and validate the audio inputs from the model kwargs.

        Returns:
            The validated audio inputs, or ``None`` if no
            ``input_features`` were passed.
        Raises:
            ValueError: if ``audio_embed_sizes`` is missing or not a tensor,
                if the features or mask have an unexpected type, or if
                tensor features are not 3D after squeezing.
        """
        input_features = kwargs.pop("input_features", None)
        input_features_mask = kwargs.pop("input_features_mask", None)
        audio_embed_sizes = kwargs.pop("audio_embed_sizes", None)
        if input_features is None:
            return None

        # The per-item audio token counts are required to rebuild the mask
        # and to split the embeddings afterwards; fail with a clear message
        # instead of a TypeError/AttributeError deep inside tensor ops.
        if not isinstance(audio_embed_sizes, torch.Tensor):
            raise ValueError("Incorrect type of audio embed sizes. "
                             f"Got type: {type(audio_embed_sizes)}")

        # Variable-length clips are padded when batched, so the padded
        # positions must be masked. The mask is not one of the processor's
        # multimodal fields, so it is usually rebuilt here from the counts.
        if input_features_mask is None:
            input_features_mask = self._build_input_features_mask(
                audio_embed_sizes)

        if not isinstance(input_features, (torch.Tensor, list)):
            raise ValueError("Incorrect type of audio input features. "
                             f"Got type: {type(input_features)}")

        if input_features_mask is not None and not isinstance(
                input_features_mask, torch.Tensor):
            raise ValueError("Incorrect type of audio input features mask. "
                             f"Got type: {type(input_features_mask)}")

        if isinstance(input_features, torch.Tensor):
            # Drop a singleton dim 1 from 4D inputs. squeeze(1) is a no-op
            # if that dim is not 1, in which case the 3D check below fails.
            if len(input_features.shape) == 4:
                input_features = input_features.squeeze(1)
            if len(input_features.shape) != 3:
                raise ValueError(
                    "Squeezed input features should be 3D but are of shape "
                    f"{input_features.shape}")
            input_features = input_features.to(
                self.encoder.input_linear.weight.dtype)

        else:
            # A list of per-item features, presumably of differing lengths;
            # pad and stack them into one [bsz, max_num_features, 160]
            # tensor.
            input_features = self._pad_and_stack_input_features(
                input_features, ).to(self.encoder.input_linear.weight.dtype)

        return GraniteSpeechAudioInputs(
            input_features=input_features,
            input_features_mask=input_features_mask,
            audio_embed_sizes=audio_embed_sizes.flatten().tolist(),
        )

    def _build_input_features_mask(
        self,
        audio_embed_sizes: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate the input features mask, which will generally be used
        to mask the padded features for all entries in the batch except
        for those with the most audio features.

        Args:
            audio_embed_sizes: torch.Tensor
                Tensor of num features in each seq in the batch.
        Returns:
            torch.Tensor: Boolean mask of shape
            (bsz, max(audio_embed_sizes)); True marks valid positions.
        """
        most_audio_features = torch.max(audio_embed_sizes).item()
        mask_indices = torch.arange(
            most_audio_features,
            device=audio_embed_sizes.device,
        ).view(1, -1)
        # Position j of entry i is valid iff j < audio_embed_sizes[i].
        input_features_mask = mask_indices < audio_embed_sizes.view(-1, 1)
        return input_features_mask

    def _pad_and_stack_input_features(
        self,
        input_features: list[torch.Tensor],
    ) -> torch.Tensor:
        """Given a list of input features of varying length, pad them to the
        same length and stack them into a torch.Tensor.

        NOTE: Usually, padding is done in the input processor/feature extractor
        and zero padded prior to the computation of the Mel features; the
        resulting values are only constant within a batch and generally nonzero
        (i.e., slightly negative nums); we should validate that this is okay
        since we don't use a feature attention mask, but the more important
        thing is that we apply the input_features_mask with variable len
        batches.

        Args:
            input_features: list[torch.Tensor]
                Input features to be coerced into a tensor; each is expected
                to be [n, num_features_i, 160] (typically n == 1).
        Returns:
            torch.Tensor: Tensor of shape [bsz, num_features, 160], where
            num_features is the max number of features of any entry in the
            batch.
        Raises:
            ValueError: if ``input_features`` is empty.
        """
        if not input_features:
            raise ValueError("Cannot pad and stack an empty list of audio "
                             "input features")

        feat_lens = [feats.shape[1] for feats in input_features]
        max_len = max(feat_lens)
        # Zero-pad dim 1 (the feature/time axis) on the right; see NOTE above
        # regarding zero vs. extractor-level padding.
        padded = [
            torch.nn.functional.pad(feats, (0, 0, 0, max_len - length, 0, 0))
            for feats, length in zip(input_features, feat_lens)
        ]
        stacked_features = torch.cat(padded, dim=0).to(input_features[0])
        return stacked_features

    def _process_audio_input(
        self,
        audio_input: GraniteSpeechAudioInputs,
    ) -> tuple[torch.Tensor, ...]:
        """Compute the audio features to be merged into the LLM embeddings.

        Args:
            audio_input: GraniteSpeechAudioInputs
                Audio inputs object containing Mel features, an input features
                mask, and the (flattened) number of audio tokens per instance.
        Returns:
            tuple[torch.Tensor, ...]: One embedding tensor per batch entry.
        """
        encoder_embeds = self.encoder(audio_input["input_features"])
        projected_embeds = self.projector(encoder_embeds)
        # Boolean indexing keeps only the valid (unpadded) positions and
        # flattens the batch into one sequence of embeddings.
        masked_embeds = projected_embeds[audio_input["input_features_mask"]]
        # Split the flattened embeddings back into per-item chunks.
        return torch.split(masked_embeds, audio_input["audio_embed_sizes"])

    def get_multimodal_embeddings(
        self,
        **kwargs: object,
    ) -> Optional[MultiModalEmbeddings]:
        """Compute the audio embeddings if audio inputs are present."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return None
        return self._process_audio_input(audio_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Compute the merged LLM / audio embeddings."""
        if multimodal_embeddings is None:
            return self.language_model.get_input_embeddings(input_ids)

        # Audio embeddings replace the positions holding audio_token_index.
        inputs_embeds = embed_multimodal(
            input_ids,
            self.config.audio_token_index,
            self.language_model.model.get_input_embeddings,
            multimodal_embeddings,
        )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if intermediate_tensors is not None:
            # Intermediate tensors are supplied (presumably a non-first
            # pipeline stage), so no input embeddings are used.
            inputs_embeds = None

        # When the caller did not precompute embeddings, merge the audio
        # embeddings into the text embeddings here.
        elif inputs_embeds is None:
            audio_embeds = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids, audio_embeds)
            input_ids = None

        model_output = self.language_model(input_ids, positions,
                                           intermediate_tensors, inputs_embeds)
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(
            hidden_states,
            sampling_metadata,
        )

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_mm_mapping(self) -> MultiModelKeys:
        """Get the module prefix in multimodal models."""
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="projector",
            tower_model="encoder",
        )
| |
|