import logging
from dataclasses import dataclass

import torch

logger = logging.getLogger(__name__)

# (layer_idx, head_idx) pairs of attention heads whose activation maps track
# text-speech alignment (see the AlignmentStreamAnalyzer docstring below)
LLAMA_ALIGNED_HEADS = [(12, 15), (13, 11), (9, 2)]


@dataclass
class AlignmentAnalysisResult:
    # was this frame detected as part of a noisy beginning chunk with potential hallucinations?
    false_start: bool
    # was this frame detected as part of a long tail with potential hallucinations?
    long_tail: bool
    # was this frame detected as repeating existing text content?
    repetition: bool
    # was the alignment position of this frame too far from the previous frame?
    discontinuity: bool
    # has inference reached the end of the text tokens?
    complete: bool
    # approximate position in the text token sequence; usable for online timestamps
    position: int
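
# NOTE: AlignmentAnalysisResult is not constructed anywhere below yet. All of its fields
# are computed as locals in `AlignmentStreamAnalyzer.step()`, so a per-frame result could
# be emitted roughly like this (a sketch, not existing behavior):
#
#   AlignmentAnalysisResult(
#       false_start=false_start,
#       long_tail=long_tail,
#       repetition=alignment_repetition or token_repetition,
#       discontinuity=discontinuity,
#       complete=self.complete,
#       position=int(cur_text_posn),
#   )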


class AlignmentStreamAnalyzer:
    def __init__(self, tfmr, queue, text_tokens_slice, alignment_layer_idx=9, eos_idx=0):
        """
        Some transformer TTS models implicitly solve text-speech alignment in one or more of their
        self-attention activation maps. This module exploits this to perform online integrity checks
        while streaming. A hook is injected into each aligned attention head, and heuristics are
        used to determine alignment position, repetition, etc.

        NOTE: currently requires no queues; `queue` and `alignment_layer_idx` are unused, and the
        heads listed in LLAMA_ALIGNED_HEADS are hooked instead of a single layer.
        """
        self.text_tokens_slice = (i, j) = text_tokens_slice
        self.eos_idx = eos_idx
        self.alignment = torch.zeros(0, j - i)

        self.curr_frame_pos = 0
        self.text_position = 0

        self.started = False
        self.started_at = None

        self.complete = False
        self.completed_at = None

        # rolling buffer of recently generated token ids, used to detect degenerate repetition
        self.generated_tokens = []

        # one attention-map slot per aligned head; each forward hook keeps its slot up to date
        self.last_aligned_attns = []
        for buffer_idx, (layer_idx, head_idx) in enumerate(LLAMA_ALIGNED_HEADS):
            self.last_aligned_attns.append(None)
            self._add_attention_spy(tfmr, buffer_idx, layer_idx, head_idx)

    def _add_attention_spy(self, tfmr, buffer_idx, layer_idx, head_idx):
        """
        Adds a forward hook to a specific attention layer to collect its attention weights.
        """
        def attention_forward_hook(module, input, output):
            """
            See `LlamaAttention.forward`; the output is a 3-tuple: `attn_output, attn_weights, past_key_value`.
            NOTE:
            - When `output_attentions=True`, `LlamaSdpaAttention.forward` calls `LlamaAttention.forward`.
            - `attn_weights` has shape [B, H, T0, T0] for the 0th step, and [B, H, 1, T0+i] for the i-th step thereafter.
            """
            if isinstance(output, tuple) and len(output) > 1 and output[1] is not None:
                step_attention = output[1].cpu()  # (B, H, T, S)
                self.last_aligned_attns[buffer_idx] = step_attention[0, head_idx]  # (T, S)

        target_layer = tfmr.layers[layer_idx].self_attn
        target_layer.register_forward_hook(attention_forward_hook)

        # attention weights are only returned when `output_attentions=True`; save the original
        # setting so it can be restored after generation
        if hasattr(tfmr, 'config') and hasattr(tfmr.config, 'output_attentions'):
            self.original_output_attentions = tfmr.config.output_attentions
            tfmr.config.output_attentions = True
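
    # Hypothetical teardown helper (a sketch, not part of the original class): the
    # `original_output_attentions` value saved above is currently never restored, so a
    # cleanup method could look like this:
    #
    #   def remove_attention_spies(self, tfmr):
    #       if hasattr(self, 'original_output_attentions'):
    #           tfmr.config.output_attentions = self.original_output_attentions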

    def step(self, logits, next_token=None):
        """
        Analyzes the alignment of the current frame and potentially modifies the logits to
        force an EOS. (Emitting an AlignmentAnalysisResult into an output queue is not
        implemented yet; see the NOTE in `__init__`.)
        """
        # average the attention maps of all aligned heads for this step
        aligned_attn = torch.stack(self.last_aligned_attns).mean(dim=0)
        i, j = self.text_tokens_slice
        if self.curr_frame_pos == 0:
            # the first forward pass covers the whole prefix; keep only the rows for generated
            # frames attending to the text-token columns
            A_chunk = aligned_attn[j:, i:j].clone().cpu()  # (T, S)
        else:
            # subsequent passes contain a single frame due to KV caching
            A_chunk = aligned_attn[:, i:j].clone().cpu()  # (1, S)

        # crude monotonic masking: assume the alignment cannot be more than one text token
        # ahead of the frame counter, and zero out attention beyond that
        A_chunk[:, self.curr_frame_pos + 1:] = 0

        self.alignment = torch.cat((self.alignment, A_chunk), dim=0)

        A = self.alignment
        T, S = A.shape

        # update the current text position, ignoring implausible jumps (heuristic thresholds)
        cur_text_posn = A_chunk[-1].argmax()
        discontinuity = not (-4 < cur_text_posn - self.text_position < 7)
        if not discontinuity:
            self.text_position = cur_text_posn

        # a false start is flagged if attention hits the last text tokens while generation has
        # barely begun, or never locks onto the first text tokens
        false_start = (not self.started) and (A[-2:, -2:].max() > 0.1 or A[:, :4].max() < 0.5)
        self.started = not false_start
        if self.started and self.started_at is None:
            self.started_at = T

        # generation is considered complete once the alignment reaches the final text tokens
        self.complete = self.complete or self.text_position >= S - 3
        if self.complete and self.completed_at is None:
            self.completed_at = T

        # total attention mass on the last few text tokens (currently unused)
        last_text_token_duration = A[15:, -3:].sum()

        # attention that lingers on the final tokens after completion is likely a hallucinated
        # long tail
        long_tail = self.complete and (A[self.completed_at:, -3:].sum(dim=0).max() >= 5)

        # attention that jumps back to earlier text tokens after completion indicates repetition
        alignment_repetition = self.complete and (A[self.completed_at:, :-5].max(dim=1).values.sum() > 5)

        # track recently generated token ids to catch degenerate single-token loops
        if next_token is not None:
            if isinstance(next_token, torch.Tensor):
                token_id = next_token.item() if next_token.numel() == 1 else next_token.view(-1)[0].item()
            else:
                token_id = next_token
            self.generated_tokens.append(token_id)

            # keep only the last 10 tokens, enough for the 10x-repetition check below
            if len(self.generated_tokens) > 10:
                self.generated_tokens = self.generated_tokens[-10:]

        # the same token generated 10 times in a row is a degenerate loop
        token_repetition = (
            len(self.generated_tokens) >= 10 and
            len(set(self.generated_tokens[-10:])) == 1
        )
        if token_repetition:
            repeated_token = self.generated_tokens[-1]
            logger.warning(f"🚨 Detected 10x repetition of token {repeated_token}")

        # suppress EOS while the alignment is still far from the end of the text
        if cur_text_posn < S - 3 and S > 5:
            logits[..., self.eos_idx] = -2**15

        # if a hallucination is detected, force an EOS token
        if long_tail or alignment_repetition or token_repetition:
            logger.warning(f"forcing EOS token, {long_tail=}, {alignment_repetition=}, {token_repetition=}")
            logits = -(2**15) * torch.ones_like(logits)
            logits[..., self.eos_idx] = 2**15

        self.curr_frame_pos += 1
        return logits
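
# ---------------------------------------------------------------------------
# Usage sketch: a minimal illustration of how the analyzer could be wired into
# a streaming TTS decode loop. `tfmr`, `sample_next_token`, `eos_token_id`, and
# the (i, j) text-token slice are hypothetical stand-ins, not part of this module.
#
#   analyzer = AlignmentStreamAnalyzer(
#       tfmr,                     # transformer whose `layers[k].self_attn` get hooked
#       None,                     # queue is currently unused (see __init__ NOTE)
#       text_tokens_slice=(i, j),
#       eos_idx=eos_token_id,
#   )
#   prev_token = None
#   for _ in range(max_new_tokens):
#       logits = ...                              # one decoder step with KV cache
#       logits = analyzer.step(logits, next_token=prev_token)
#       prev_token = sample_next_token(logits)    # hypothetical sampler
#       if prev_token == eos_token_id:
#           break
# ---------------------------------------------------------------------------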