from typing import Optional

import torch
from torch import nn
from transformers import LlamaConfig, LlamaModel, LlamaPreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions


class T3HuggingfaceBackend(LlamaPreTrainedModel, GenerationMixin):
    """
    Override some HuggingFace interface methods so we can use the standard `generate` method with our
    custom embedding / logit layers.

    NOTE: need to extend "*PreTrainedModel" to avoid re-initializing weights!
    """

    def __init__(
        self,
        config: LlamaConfig,
        llama: LlamaModel,
        *,
        speech_enc,   # custom speech token embedding layer
        speech_head,  # projects hidden states to speech token logits
        latents_queue=None,
        logits_queue=None,
        alignment_stream_analyzer: 'AlignmentStreamAnalyzer' = None,
    ):
        super().__init__(config)
        self.model = llama
        self.speech_enc = speech_enc
        self.speech_head = speech_head
        self._added_cond = False  # whether the decoder conditioning prefix has been prepended yet
        self.alignment_stream_analyzer = alignment_stream_analyzer

    def prepare_inputs_for_generation(
        self, input_ids: torch.Tensor, decoder_cond: torch.Tensor, use_cache: bool, past_key_values=None,
        # This argument was introduced in some recent version of transformers (>=4.29.1)
        cache_position=None,
    ):
        """
        Called by HuggingFace's generate() method. Overridden here to apply our custom speech token
        embedding layer.

        :param input_ids: (B, S) int64 tensor of input tokens.
        :param decoder_cond: (B, T, C) float32 tensor of conditioning (prefixed to <input_embeds>)
        """

        # Make use of the kv cache: only the last input ID is new; trim away all the ones before it
        if not use_cache:
            past_key_values = None
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        # custom speech token embedding layer
        inputs_embeds = self.speech_enc(input_ids)

        # prefix decoder conditioning if applicable (only once, on the first step)
        if not self._added_cond:
            # should be the first step; recent transformers passes an empty cache object here (not None)
            assert past_key_values is not None
            if decoder_cond.size(0) != inputs_embeds.size(0):
                decoder_cond = decoder_cond.expand(inputs_embeds.size(0), -1, -1)  # broadcast across batch (e.g. for CFG)
            inputs_embeds = torch.cat([decoder_cond, inputs_embeds], dim=1)
            self._added_cond = True

        return {
            "inputs_embeds": inputs_embeds,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }
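
    # Shape sketch (illustrative; T = conditioning length, C = hidden size):
    #   step 1: the cache is empty, so inputs_embeds = [decoder_cond (B, T, C); token emb (B, 1, C)]
    #           giving (B, T+1, C), which fills the kv cache in a single forward pass
    #   step k>1: input_ids is trimmed to the newest token, so inputs_embeds is (B, 1, C) and all
    #           earlier positions are served from the cache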

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        past_key_values: Optional[torch.Tensor] = None,
        use_cache=True,
        output_attentions=False,
        output_hidden_states=True,
        return_dict=True,
    ):
        """
        Called by HuggingFace's generate() method. Overridden here to apply our custom layer norm and
        speech logit projection layers.

        :param inputs_embeds: (B, S, C) float32 tensor of input embeddings (decoder conditioning prefix
            plus speech token embeddings). If past key values are given, S should be 1.
        """
        is_large_input = inputs_embeds.size(1) != 1
        has_cache = past_key_values is not None and len(past_key_values) > 0
        assert not (is_large_input and has_cache)
        assert return_dict
        assert output_hidden_states

        tfmr_out = self.model(
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        hidden_states = tfmr_out.hidden_states[-1]  # (B, seq, dim)
        logits = self.speech_head(hidden_states)
        # assert inputs_embeds.size(0) == 1  # (disabled for CFG)

        # NOTE: hallucination handler may modify logits to force emit an EOS token
        # logits = self.alignment_stream_analyzer.step(logits)

        return CausalLMOutputWithCrossAttentions(
            logits=logits,
            past_key_values=tfmr_out.past_key_values,
            hidden_states=tfmr_out.hidden_states,
            attentions=tfmr_out.attentions,
        )
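

# ------------------------------------------------------------------
# Usage sketch (not from the original file): a minimal greedy decode
# loop driving the backend by hand. The config sizes, start token id,
# and the nn.Embedding / nn.Linear stand-ins for `speech_enc` and
# `speech_head` are assumptions for illustration only. Requires a
# transformers version that provides DynamicCache.
# ------------------------------------------------------------------
if __name__ == "__main__":
    from transformers.cache_utils import DynamicCache

    config = LlamaConfig(
        vocab_size=1024, hidden_size=256, num_hidden_layers=2,
        num_attention_heads=4, intermediate_size=512,
    )
    backend = T3HuggingfaceBackend(
        config,
        LlamaModel(config),
        speech_enc=nn.Embedding(1024, 256),  # stand-in speech token embedding
        speech_head=nn.Linear(256, 1024),    # stand-in speech logit projection
    )
    backend.eval()

    input_ids = torch.zeros(1, 1, dtype=torch.long)  # hypothetical start token
    decoder_cond = torch.randn(1, 8, 256)            # (B, T, C) conditioning prefix
    past = DynamicCache()                            # empty cache; filled on step 1

    with torch.no_grad():
        for _ in range(4):
            inputs = backend.prepare_inputs_for_generation(
                input_ids, decoder_cond, use_cache=True, past_key_values=past,
            )
            out = backend(**inputs)
            past = out.past_key_values
            next_id = out.logits[:, -1:].argmax(dim=-1)  # greedy next speech token
            input_ids = torch.cat([input_ids, next_id], dim=1)

    print(input_ids)  # (1, 5): start token + 4 generated speech tokens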