| """ |
| Modified From https://github.com/XXXXRT666/GPT-SoVITS |
| """ |
|
|
| import gc |
| import os |
| import time |
| import traceback |
| from typing import Dict, List, Tuple |
|
|
| import flash_attn |
| import torch |
| import torch.nn as nn |
| from tqdm import tqdm |
|
|
| from AR.models.embedding import ( |
| SinePositionalEmbeddingNested as SinePositionalEmbedding, |
| ) |
| from AR.models.embedding import TokenEmbedding |
| from AR.models.structs import T2SRequest, T2SResult, T2SSession |
| from AR.models.t2s_model_abc import ( |
| AttentionABC, |
| CUDAGraphCacheABC, |
| FeedForward, |
| KVCacheABC, |
| KVCacheNHD, |
| T2SDecoderABC, |
| TorchProfiler, |
| TransformerBlockABC, |
| TransformerDecoderABC, |
| ) |
|
|
| Tensor = torch.Tensor |
|
|
|
|
| class Attention(AttentionABC): |
| def __init__(self, n_head: int, hidden_dim: int): |
| super().__init__() |
| self.n_head = n_head |
| self.hidden_dim = hidden_dim |
| assert hidden_dim % n_head == 0 |
| self.head_dim = hidden_dim // n_head |
|
|
        # Fused QKV projection; split into q/k/v and reshaped per-head in forward()
        self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
|
| def forward(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheABC, *args, **kwds) -> Tensor: |
| bsz, seqlen, _ = x.shape |
|
|
| q, k, v = self.in_proj.forward(x).chunk(3, dim=-1) |
|
|
        # Reshape to (bsz, seqlen, n_head, head_dim), the layout flash_attn expects
        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_head, self.head_dim)
        v = v.view(bsz, seqlen, self.n_head, self.head_dim)
|
|
        # flash_attn_with_kvcache appends k/v to the cache in-place at offset
        # cache_seqlens (= input_pos - 1, the length before this step) and attends
        # over the cached prefix, so no explicit attention mask is needed here.
        attn: Tensor = flash_attn.flash_attn_with_kvcache(
            q, kv_cache.k_cache, kv_cache.v_cache, k, v, cache_seqlens=input_pos - 1
        )
|
|
| attn = self.dropout.forward(attn) |
|
|
| attn = attn.view(bsz, seqlen, self.hidden_dim) |
|
|
| attn = self.out_proj.forward(attn) |
|
|
| return attn |
|
|
|
|
| class TransformerBlock(TransformerBlockABC): |
    def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int) -> None:
| super().__init__() |
| self.hidden_dim = hidden_dim |
| self.attention = Attention(n_head, hidden_dim) |
| self.feed_forward = FeedForward(hidden_dim, ffn_dim) |
| self.attention_norm = nn.LayerNorm([self.hidden_dim]) |
| self.ffn_norm = nn.LayerNorm([self.hidden_dim]) |
|
|
|
|
| class TransformerDecoder(TransformerDecoderABC): |
    def __init__(
        self,
        hidden_dim: int,
        n_layer: int,
        n_head: int,
        ffn_dim: int,
        vocab_size: int,
        max_seq_length: int,
        max_batch_size: int,
    ) -> None:
| super().__init__() |
|
|
| self.hidden_dim = hidden_dim |
| self.n_head = n_head |
| assert hidden_dim % n_head == 0 |
|
|
| self.head_dim = hidden_dim // n_head |
| self.vocab_size = vocab_size |
|
|
| self.n_layer = n_layer |
|
|
| self.layers = nn.ModuleList( |
| TransformerBlock(n_head, ffn_dim, hidden_dim) for _ in range(n_layer) |
| ) |
|
|
| self.max_seq_length: int = max_seq_length |
| self.max_batch_size: int = max_batch_size |
|
|
| self.setup_caches(self.max_batch_size, self.max_seq_length) |
|
|
| def setup_caches(self, max_batch_size=10, max_seq_length=2500): |
| self.max_seq_length = max_seq_length |
| self.max_batch_size = max_batch_size |
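        # Only the size limits are recorded here; the KV cache tensors themselves
        # are allocated per session via T2SDecoder.init_cache.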
|
|
|
|
| class T2SDecoder(T2SDecoderABC): |
| def __init__( |
| self, |
| config, |
| *args, |
| norm_first=False, |
| max_seq_length=2500, |
| max_batch_size=10, |
| **kwds, |
| ) -> None: |
| assert torch.cuda.is_available() |
| super().__init__() |
|
|
| hidden_dim = config["model"]["hidden_dim"] |
| embedding_dim = config["model"]["embedding_dim"] |
| n_head = config["model"]["head"] |
| n_layer = config["model"]["n_layer"] |
| vocab_size = config["model"]["vocab_size"] |
| phoneme_vocab_size = config["model"]["phoneme_vocab_size"] |
| p_dropout = config["model"]["dropout"] |
| EOS = config["model"]["EOS"] |
| ffn_dim = hidden_dim * 4 |
| self.norm_first = norm_first |
|
|
| self.n_layer = n_layer |
| self.hidden_dim = hidden_dim |
| self.n_head = n_head |
| assert hidden_dim % n_head == 0 |
|
|
| self.head_dim = hidden_dim // n_head |
| self.embedding_dim = embedding_dim |
| self.vocab_size = vocab_size |
| self.phoneme_vocab_size = phoneme_vocab_size |
| self.p_dropout = p_dropout |
| self.max_seq_length = max_seq_length |
| self.max_batch_size = max_batch_size |
| self.EOS = EOS |
| assert self.EOS == self.vocab_size - 1 |
|
|
        # 1024 is the expected width of the incoming BERT features
        self.bert_proj = nn.Linear(1024, self.embedding_dim)
| self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size, self.p_dropout) |
| self.ar_text_position = SinePositionalEmbedding( |
| self.embedding_dim, |
| dropout=0.1, |
| scale=False, |
| alpha=True, |
| max_batch_size=max_batch_size, |
| max_seq_len=max_seq_length, |
| ) |
| self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size, self.p_dropout) |
| self.ar_audio_position = SinePositionalEmbedding( |
| self.embedding_dim, |
| dropout=0.1, |
| scale=False, |
| alpha=True, |
| max_batch_size=max_batch_size, |
| max_seq_len=max_seq_length, |
| ) |
| self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False) |
| self.h: TransformerDecoderABC = TransformerDecoder( |
| hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size |
| ) |
|
|
        # NHD cache layout (seqlen, n_heads, head_dim), as consumed by flash_attn_with_kvcache
        self.kv_class = KVCacheNHD
        self._register_load_state_dict_pre_hook(self.load_hook)
|
|
| def embed( |
| self, |
| x: List[torch.Tensor], |
| y: torch.Tensor, |
| bert_features: List[torch.Tensor], |
| ): |
| x_nested = torch.nested.nested_tensor(x) |
| assert x_nested.size(0) <= self.max_batch_size |
        bert_features_nested = torch.nested.nested_tensor([f.transpose(0, 1) for f in bert_features])
|
|
| x_emb = self.ar_text_embedding.forward(x_nested) |
| bert = self.bert_proj.forward(bert_features_nested) |
| x_emb = x_emb + bert |
| x_pos = self.ar_text_position.prefill(x_emb) |
|
|
| y_nested = torch.nested.nested_tensor(list(y.unbind(0))) |
| y_emb = self.ar_audio_embedding.forward(y_nested) |
| y_pos = self.ar_audio_position.prefill(y_emb) |
|
|
        # Concatenate each sample's text and audio positions into one ragged batch
        xy_pos = torch.nested.nested_tensor([torch.cat([x_pos[i], y_pos[i]]) for i in range(len(x))])
        return xy_pos
|
|
    def post_forward(self, idx: int, session: T2SSession) -> None:
        # No per-step bookkeeping is required by the flash-attn backend.
        pass


    def pre_forward(self, session: T2SSession) -> Tuple[List, Dict]:
        # No extra args/kwargs are forwarded to the transformer for this backend.
        return list(), dict()
|
|
|
|
| class CUDAGraphCache(CUDAGraphCacheABC): |
| def __init__( |
| self, |
| decoder: T2SDecoderABC, |
| device: torch.device = torch.device("cpu"), |
| dtype: torch.dtype = torch.float32, |
| ) -> None: |
| super().__init__(decoder, device, dtype) |
|
|
    def release_graph(self, session: T2SSession):
        if session.id == self.id:
            # The session borrowed the shared cached graph: just mark it free.
            self.assigned = False
        else:
            # The session captured its own graph: drop its private tensors.
            del session.graph, session.xy_pos_, session.xy_dec_, session.input_pos, session.kv_cache
|
|
| def get_cache_graph(self, session: T2SSession): |
| assert self.graph |
| session.graph = self.graph |
|
|
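        # Hand out the static tensors the graph was captured with; input_pos is
        # copied in-place so the graph keeps pointing at the same storage.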
| session.xy_pos_ = self.xy_pos |
| session.xy_dec_ = self.xy_dec |
| session.input_pos = self.input_pos.copy_(session.input_pos) |
|
|
| for cache, cache_ in zip(self.kv_cache, session.kv_cache): |
| cache.sync_cache(cache_) |
|
|
    def capture_new_graph(self, session: T2SSession):
        session.xy_pos_ = self.xy_pos.clone()
        session.xy_dec_ = self.xy_dec.clone()
        session.input_pos = self.input_pos.clone().copy_(session.input_pos)


        # Capture against the session's own cloned tensors and KV cache so the new
        # graph does not alias the shared cache entry.
        args, kwds = self.decoder.pre_forward(session)
        graph = self.decoder.capture(
            session.input_pos, session.xy_pos_, session.xy_dec_, *args, kv_caches=session.kv_cache, **kwds
        )
        session.graph = graph
|
|
|
|
| class CUDAGraphRunner: |
| def __init__( |
| self, |
| decoder_model: T2SDecoderABC, |
| device: torch.device = torch.device("cpu"), |
| dtype: torch.dtype = torch.float32, |
| ) -> None: |
| assert device.type == "cuda" |
| self.device = device |
| self.dtype = dtype |
|
|
| self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype) |
|
|
        self.graphcache = CUDAGraphCache(self.decoder_model, device, dtype)
|
|
| def _handle_request(self, request: T2SRequest): |
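        # Runs the full autoregressive decode loop for one request and returns
        # (per-sample semantic token tensors, measured tokens/s).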
| with self.device: |
| decoder = self.decoder_model |
| session = T2SSession(decoder, request, device=self.device, dtype=self.dtype) |
|
|
| t1 = 0.0 |
| infer_speed = 0.0 |
|
|
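            # Profiling is active only when request.debug is set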
| torch_profiler = TorchProfiler(request.debug) |
| with torch_profiler.profiler(): |
                for idx in tqdm(range(1500)):  # hard cap on decode steps (see idx == 1499 below)
                    if idx == 0:
                        # Prefill: allocate the KV cache and run the whole prompt through
                        # the transformer; only the last position feeds the first prediction.
                        session.kv_cache = decoder.init_cache(session.bsz)
                        xy_dec = decoder.h.prefill(session.xy_pos, session.attn_mask_nested, session.kv_cache)
                        xy_dec = torch.stack([t[[-1]] for t in xy_dec.unbind()])
| else: |
| if request.use_cuda_graph and session.graph is None and torch.cuda.is_available(): |
| self.graphcache.assign_graph(session) |
|
|
| with torch_profiler.record("AR"): |
| if session.graph: |
| session.xy_pos_.copy_(session.xy_pos) |
| session.graph.replay() |
| xy_dec = session.xy_dec_.clone() |
| else: |
| args, kwds = decoder.pre_forward(session) |
| xy_dec = decoder.h.forward( |
| session.input_pos, |
| session.xy_pos, |
| session.kv_cache, |
| *args, |
| **kwds, |
| ) |
|
|
| decoder.post_forward(idx, session) |
| logits = decoder.ar_predict_layer(xy_dec[:, -1]) |
| session.input_pos.add_(1) |
|
|
                    if idx == 0:
                        logits[:, -1] = float("-inf")  # forbid EOS as the very first generated token
|
|
| with torch_profiler.record("Sampling"): |
| samples = session.sampler.sample( |
| logits=logits, |
| previous_tokens=session.y, |
| top_k=request.top_k, |
| top_p=request.top_p, |
| repetition_penalty=request.repetition_penalty, |
| temperature=request.temperature, |
| ) |
|
|
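                    # Append the newly sampled token to each running sequence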
| session.y = torch.cat([session.y, samples], dim=1) |
|
|
| with torch_profiler.record("EOS"): |
| argmax_token = torch.argmax(logits, dim=-1) |
| sample_token = samples.squeeze(1) |
| EOS_mask = (argmax_token == decoder.EOS) | (sample_token == decoder.EOS) |
|
|
| newly_done_mask = EOS_mask & (~session.completed) |
| newly_done_indices = newly_done_mask.nonzero() |
|
|
                        if newly_done_indices.numel() > 0:
                            # Record every sequence that hit EOS this step (not just the
                            # first), dropping the trailing EOS token from each result.
                            for i in newly_done_indices.view(-1).tolist():
                                session.y_results[i] = session.y[i, session.y_len : -1]
                            session.completed[newly_done_indices] = True
|
|
                        if torch.all(session.completed).item():
                            if session.y.size(1) == 0:
                                session.y = torch.cat([session.y, torch.zeros_like(samples)], dim=1)
                                tqdm.write("Bad Zero Prediction")
                            else:
                                prefill_lens = str(session.prefill_len.tolist()).strip("[]")
                                result_lens = str([i.size(0) for i in session.y_results]).strip("[]")
                                tqdm.write(f"T2S Decoding EOS {prefill_lens} -> \n{result_lens}")
                                # Compute the speed once; t1 was recorded at idx == 2
                                infer_speed = (idx - 1) / (time.perf_counter() - t1)
                                tqdm.write(f"Infer Speed: {infer_speed:.2f} token/s")
                            break
|
|
| if ( |
| request.early_stop_num != -1 |
| and (session.y.size(1) - session.y_len) > request.early_stop_num |
| ) or idx == 1499: |
| for i in range(session.bsz): |
| if not session.completed[i].item(): |
| session.y_results[i] = session.y[i, session.y_len :] |
| session.completed[i] = True |
| break |
|
|
                    with torch_profiler.record("NextPos"):
                        # Embed the newly sampled token and position it at the next
                        # audio offset to form the input for the following step.
                        y_emb = decoder.ar_audio_embedding(session.y[:, -1:])
                        session.xy_pos = decoder.ar_audio_position.forward(session.input_pos - session.x_lens, y_emb)
|
|
                    if idx == 2:
                        # Start profiling/timing after warmup so first-step and graph
                        # capture overhead do not skew the measured speed.
                        torch_profiler.start()
                        t1 = time.perf_counter()
|
|
| if idx == 51: |
| torch_profiler.end() |
|
|
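                    # Periodically release cached allocator blocks on the active backend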
| if idx % 100 == 0: |
| match session.device.type: |
| case "cuda": |
| torch.cuda.empty_cache() |
| case "mps": |
| torch.mps.empty_cache() |
| case "xpu": |
| torch.xpu.empty_cache() |
| case "mtia": |
| torch.mtia.empty_cache() |
|
|
| match session.device.type: |
| case "cuda": |
| torch.cuda.empty_cache() |
| case "mps": |
| torch.mps.empty_cache() |
| case "xpu": |
| torch.xpu.empty_cache() |
| case "mtia": |
| torch.mtia.empty_cache() |
| case "cpu": |
| gc.collect() |
|
|
| torch_profiler.end() |
| self.graphcache.release_graph(session) |
| return session.y_results[: request.valid_length], infer_speed |
|
|
| def generate(self, request: T2SRequest): |
| try: |
| result, infer_speed = self._handle_request(request) |
| t2s_result = T2SResult(result=result, infer_speed=infer_speed, status="Success") |
| except Exception as e: |
| t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc()) |
| return t2s_result |
|
|
| @staticmethod |
| def load_decoder(weights_path: os.PathLike, implement: str = "flash_attn"): |
        print(f"Loading Text2Semantic Weights from {weights_path} with {implement.replace('_', ' ').title()} Implementation")
        module_path = f"AR.models.t2s_model_{implement.lower()}"
        cls_name = "T2SDecoder"
        # Import the backend-specific decoder module (this file for "flash_attn")
        mod = __import__(module_path, fromlist=[cls_name])
        decoder_cls: type[T2SDecoderABC] = getattr(mod, cls_name)
| dict_s1 = torch.load(weights_path, map_location="cpu", weights_only=False, mmap=True) |
| config = dict_s1["config"] |
| decoder: T2SDecoderABC = decoder_cls(config, max_batch_size=1) |
| state_dict = dict_s1["weight"] |
| decoder.load_state_dict(state_dict) |
| return decoder.eval() |
|
|