| |
|
|
| from typing import Any, Literal, Optional |
|
|
| import torch |
| |
| |
|
|
| from litgpt.model import GPT |
| from utils.snac_utils import layershift, snac_config |
| from tqdm import tqdm |
|
|
|
|
| def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor: |
| if torch._dynamo.is_compiling(): |
| |
| distribution = torch.empty_like(probs).exponential_(1) |
| return torch.argmax(probs / distribution, dim=-1, keepdim=True) |
| return torch.multinomial(probs, num_samples=1) |
|
|
|
|
def sample_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Mask out the low-probability tail for nucleus (top-p) sampling.

    The logits are sorted in ascending order along the last dimension and the
    cumulative probability mass is computed. Every entry whose cumulative mass
    is ``<= 1 - top_p`` is replaced by ``-inf``. The most probable entry is
    always kept so at least one token remains selectable.

    Example (``top_p=0.7``)::

        sorted probs      = [0.1, 0.15, 0.2, 0.25, 0.3]
        cumulative probs  = [0.1, 0.25, 0.45, 0.7, 1.0]
        removed (sorted)  = [1,   1,    0,    0,   0]

    Works on a 1-D vector of logits as well as on batched logits of shape
    ``(..., vocab)``; filtering is applied independently along the last axis.

    Args:
        logits: Unnormalised scores; the vocabulary axis is the last one.
        top_p: Cumulative probability mass to keep.

    Returns:
        A new tensor shaped like ``logits`` with filtered entries set to
        ``-inf``. The input tensor is not modified.
    """
    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Always keep the last (most probable) entry of every row.
    sorted_indices_to_remove[..., -1:] = False
    # Map the mask from sorted order back to the original vocabulary order.
    indices_to_remove = sorted_indices_to_remove.scatter(
        -1, sorted_indices, sorted_indices_to_remove
    )
    return logits.masked_fill(indices_to_remove, float("-inf"))
|
|
|
|
def sample(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
) -> torch.Tensor:
    """Pick the next token id from the logits of the last position.

    Only ``logits[0, -1]`` is used, i.e. the first batch row and the final
    position. Filters are applied in this order: ``top_k``, ``temperature``
    scaling, ``top_p``.

    Args:
        logits: Model output indexed as ``[batch, position, vocab]``.
        temperature: Divides the logits when greater than zero.
        top_k: If given, only the ``top_k`` largest logits stay selectable.
        top_p: Nucleus threshold; filtering is applied only when below 1.

    Returns:
        A tensor holding one token id (last dimension kept, size 1).

    Raises:
        ValueError: If ``top_p`` lies outside ``[0, 1]``.
    """
    if top_p > 1.0 or top_p < 0.0:
        raise ValueError(f"top_p must be in [0, 1], got {top_p}")

    last_logits = logits[0, -1]

    if top_k is not None:
        keep = min(top_k, last_logits.size(-1))
        top_values, top_indices = torch.topk(last_logits, keep)
        # Everything outside the top-k set becomes -inf.
        blank = torch.full_like(last_logits, float("-inf"))
        last_logits = blank.scatter_(-1, top_indices, top_values)

    # Greedy decoding is used only when both knobs are switched off.
    stochastic = temperature > 0.0 or top_p > 0.0
    if not stochastic:
        return torch.argmax(last_logits, dim=-1, keepdim=True)

    if temperature > 0.0:
        last_logits = last_logits / temperature
    if top_p < 1.0:
        last_logits = sample_top_p(last_logits, top_p)
    probabilities = torch.nn.functional.softmax(last_logits, dim=-1)
    return multinomial_num_samples_1(probabilities)
|
|
|
|
def next_token(
    model: GPT, input_pos: torch.Tensor, x: list, **kwargs: Any
) -> torch.Tensor:
    """Run one forward pass and sample the next token for every stream.

    The model is called as ``model(None, x, None, input_pos)`` and is expected
    to return a pair ``(audio_logits, text_logits)`` where ``audio_logits`` is
    iterable.

    Args:
        model: The model to query.
        input_pos: Positions of the fed tokens; moved to ``model.device``.
        x: List of input id tensors, one per stream.
        **kwargs: Sampling options forwarded to ``sample``.

    Returns:
        ``(audio_tokens, text_token)``: a list with one sampled token per audio
        logits entry and the sampled text token, all cast to ``x[0].dtype``.
        Note that this is a tuple, not the single tensor the annotation
        suggests.
    """
    positions = input_pos.to(model.device)
    audio_logits, text_logits = model(None, x, None, positions)

    id_dtype = x[0].dtype
    audio_tokens = [
        sample(layer_logits, **kwargs).to(dtype=id_dtype)
        for layer_logits in audio_logits
    ]
    text_token = sample(text_logits, **kwargs).to(dtype=id_dtype)
    return audio_tokens, text_token
|
|
|
|
def next_token_asr(
    model: GPT,
    input_pos: torch.Tensor,
    audio_features: torch.tensor,
    lens: int,
    input_ids: list,
    **kwargs: Any,
) -> torch.Tensor:
    """Run one forward pass with audio features and sample every stream.

    Args:
        model: The model to query.
        input_pos: Positions of the fed tokens; moved to ``model.device``.
        audio_features: Passed to the model as its first argument.
        lens: Forwarded to the model as ``whisper_lens``.
        input_ids: List of input id tensors; each is moved to ``model.device``.
        **kwargs: Sampling options forwarded to ``sample``.

    Returns:
        ``(audio_tokens, text_token)``: a list of sampled audio tokens and the
        sampled text token, all cast to the dtype of ``input_ids[0]``.
    """
    positions = input_pos.to(model.device)
    ids_on_device = [ids.to(model.device) for ids in input_ids]
    audio_logits, text_logits = model(
        audio_features, ids_on_device, None, positions, whisper_lens=lens
    )

    id_dtype = ids_on_device[0].dtype
    audio_tokens = [
        sample(layer_logits, **kwargs).to(dtype=id_dtype)
        for layer_logits in audio_logits
    ]
    text_token = sample(text_logits, **kwargs).to(dtype=id_dtype)
    return audio_tokens, text_token
|
|
|
|
def next_token_A1T2(
    model: GPT,
    audio_features: torch.tensor,
    input_ids: list,
    whisper_lens: int,
    task: list,
    input_pos: torch.Tensor,
    **kwargs: Any,
) -> torch.Tensor:
    """Run one forward pass and sample both the audio and the text streams.

    Args:
        model: The model to query.
        audio_features: Passed to the model as its first argument (may be
            ``None``, as the callers in this file do after the first step).
        input_ids: List of input id tensors; each is moved to ``model.device``.
        whisper_lens: Forwarded to the model unchanged.
        task: Forwarded to the model unchanged.
        input_pos: Positions of the fed tokens; moved to ``model.device``.
        **kwargs: Sampling options forwarded to ``sample``.

    Returns:
        ``(audio_tokens, text_token)``: a list of sampled audio tokens and the
        sampled text token, all cast to the dtype of ``input_ids[0]``.
    """
    positions = input_pos.to(model.device)
    ids_on_device = [ids.to(model.device) for ids in input_ids]
    audio_logits, text_logits = model(
        audio_features,
        ids_on_device,
        None,
        positions,
        whisper_lens=whisper_lens,
        task=task,
    )

    id_dtype = ids_on_device[0].dtype
    audio_tokens = [
        sample(layer_logits, **kwargs).to(dtype=id_dtype)
        for layer_logits in audio_logits
    ]
    text_token = sample(text_logits, **kwargs).to(dtype=id_dtype)
    return audio_tokens, text_token
|
|
|
|
def next_token_A1T1(
    model: GPT,
    audio_features: torch.tensor,
    input_ids: list,
    whisper_lens: int,
    task: list,
    input_pos: torch.Tensor,
    **kwargs: Any,
) -> torch.Tensor:
    """Run one forward pass and sample only the text stream.

    The audio logits returned by the model are discarded.

    Args:
        model: The model to query.
        audio_features: Passed to the model as its first argument.
        input_ids: List of input id tensors; each is moved to ``model.device``.
        whisper_lens: Forwarded to the model unchanged.
        task: Forwarded to the model unchanged.
        input_pos: Positions of the fed tokens; moved to ``model.device``.
        **kwargs: Sampling options forwarded to ``sample``.

    Returns:
        The sampled text token, cast to the dtype of ``input_ids[0]``.
    """
    positions = input_pos.to(model.device)
    ids_on_device = [ids.to(model.device) for ids in input_ids]
    _, text_logits = model(
        audio_features,
        ids_on_device,
        None,
        positions,
        whisper_lens=whisper_lens,
        task=task,
    )
    text_token = sample(text_logits, **kwargs)
    return text_token.to(dtype=ids_on_device[0].dtype)
|
|
|
|
def next_token_image_batch(
    model: GPT,
    audio_features: torch.tensor,
    clip_features: torch.tensor,
    input_ids: list,
    whisper_lens: int,
    task: list,
    input_pos: torch.Tensor,
    **kwargs: Any,
) -> torch.Tensor:
    """Run one batched forward pass and sample audio and text tokens.

    Audio tokens are sampled from batch row 0 of the first seven audio logits,
    the text token from batch row 1 of the text logits.
    NOTE(review): this looks like a two-row batch where row 0 carries the
    audio answer and row 1 the text answer -- confirm against the model.

    Args:
        model: The model to query.
        audio_features: Passed to the model as its first argument.
        clip_features: Passed to the model as its third argument.
        input_ids: List of input id tensors; each is moved to ``model.device``.
        whisper_lens: Forwarded to the model unchanged.
        task: Forwarded to the model unchanged.
        input_pos: Positions of the fed tokens; moved to ``model.device``.
        **kwargs: Sampling options forwarded to ``sample``.

    Returns:
        ``(audio_tokens, text_token)``: a list of sampled audio tokens and the
        sampled text token, all cast to the dtype of ``input_ids[0]``.
    """
    positions = input_pos.to(model.device)
    ids_on_device = [ids.to(model.device) for ids in input_ids]
    logits_a, logit_t = model(
        audio_features,
        ids_on_device,
        clip_features,
        positions,
        whisper_lens=whisper_lens,
        task=task,
    )

    # Reduce to a batch of one: row 0 for each audio layer, row 1 for text.
    for layer in range(7):
        logits_a[layer] = logits_a[layer][0].unsqueeze(0)
    logit_t = logit_t[1].unsqueeze(0)

    id_dtype = ids_on_device[0].dtype
    audio_tokens = [
        sample(layer_logits, **kwargs).to(dtype=id_dtype)
        for layer_logits in logits_a
    ]
    text_token = sample(logit_t, **kwargs).to(dtype=id_dtype)
    return audio_tokens, text_token
|
|
|
|
| |
| |
| |
| |
|
|
|
|
@torch.inference_mode()
def generate(
    model: GPT,
    input_ids: list,
    max_returned_tokens: int,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> torch.Tensor:
    """
    Continue an eight-stream prompt (seven audio layers plus text) token by token.

    The implementation of this function is modified from A. Karpathy's nanoGPT.

    Args:
        model: The model to use.
        input_ids: Exactly eight 1-D tensors of equal length T: the seven audio
            layer prompts followed by the text prompt.
        max_returned_tokens: The maximum number of tokens to return per stream
            (given plus generated). Must be greater than T and must not exceed
            ``model.max_seq_length + 1``.
        temperature: Scales the predicted logits by 1 / temperature.
        top_k: If specified, only sample among the tokens with the k highest probabilities.
        top_p: If specified, it represents the cumulative probability threshold to consider
            in the sampling process. It must be `0 <= top_p <= 1`; `top_p=1` samples from the
            whole distribution. It can be used in conjunction with `top_k` and `temperature`
            with the following order of application:

            1. `top_k` sampling
            2. `temperature` scaling
            3. `top_p` sampling

            For more details, see https://arxiv.org/abs/1904.09751
            or https://huyenchip.com/2024/01/16/sampling.html#top_p
        eos_id_a: Generation stops once the last audio layer emits this id.
        eos_id_t: Text end-of-sequence id. When emitted, generation stops if
            `generate_text` is true; otherwise every later text token is
            replaced by `pad_id`.
        pad_id: Text id used after the text stream has finished. Must be set
            whenever the text stream can finish before the audio stream.
        shift: Offset added to each sampled audio token before it is fed back
            (layer ``i`` additionally gets ``i * snac_config.padded_vocab_size``).
            Must be provided; adding ``None`` to a tensor raises ``TypeError``.
        include_prompt: Currently not used; the prompt is always part of the output.
        generate_text: If true, stop as soon as the text stream emits `eos_id_t`.

    Returns:
        A list of eight 1-D tensors (seven audio layers, then text), each made
        of the prompt followed by the generated tokens.
    """
    T = input_ids[0].size(0)
    device = input_ids[0].device
    assert max_returned_tokens > T
    if model.max_seq_length < max_returned_tokens - 1:
        raise NotImplementedError(
            f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
        )

    # Unpacking enforces that exactly eight streams were supplied.
    (
        tokens_A1,
        tokens_A2,
        tokens_A3,
        tokens_A4,
        tokens_A5,
        tokens_A6,
        tokens_A7,
        tokens_T,
    ) = input_ids
    prompts = [
        tokens_A1,
        tokens_A2,
        tokens_A3,
        tokens_A4,
        tokens_A5,
        tokens_A6,
        tokens_A7,
        tokens_T,
    ]

    # One growing list of tensors per stream, seeded with the prompt.
    list_output = [[prompt] for prompt in prompts]

    # Prefill: feed the whole prompt and sample the first new token.
    input_pos = torch.tensor([T], device=device)
    model_input_ids = [prompt.view(1, -1) for prompt in prompts]
    tokens_A, token_T = next_token(
        model,
        torch.arange(0, T, device=device),
        model_input_ids,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    for i in range(7):
        list_output[i].append(tokens_A[i].clone())
    list_output[7].append(token_T.clone())

    # Sampled audio ids are shifted into the model's input id space before
    # being fed back; the unshifted ids are what gets returned.
    for i in range(7):
        tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
    token_T = token_T.clone()

    text_end = False
    # The caller's `max_returned_tokens` (validated above) bounds the loop; it
    # is no longer overwritten by a hard-coded constant.
    for _ in tqdm(range(2, max_returned_tokens - T + 1)):
        model_input_ids = [
            token_a.view(1, -1).to(torch.int32) for token_a in tokens_A
        ] + [token_T.view(1, -1).to(torch.int32)]
        tokens_A, token_T = next_token(
            model,
            input_pos,
            model_input_ids,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        if text_end:
            # Text already finished: keep padding while audio continues.
            token_T = torch.tensor([pad_id], device=device)

        for i in range(7):
            list_output[i].append(tokens_A[i].clone())
        list_output[7].append(token_T.clone())

        if tokens_A[-1] == eos_id_a:
            break
        if token_T == eos_id_t:
            if generate_text:
                break
            text_end = True

        for i in range(7):
            tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
        token_T = token_T.clone()
        input_pos = input_pos.add_(1)

    for i in range(len(list_output)):
        list_output[i] = torch.cat(list_output[i])
    return list_output
|
|
|
|
@torch.inference_mode()
def generate_TA_BATCH(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 1000,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> torch.Tensor:
    """Generate audio and text together using a two-row batch.

    The prompt is fed with ``whisper_lens=[T - 3, T - 3]`` and
    ``task=["A1T2", "A1T2"]``. Afterwards each step feeds a two-row batch: for
    every audio layer the shifted sampled token and that layer's
    end-of-audio id, and for the text stream the sampled text token twice.

    ``leng``, ``task``, ``include_prompt`` and ``generate_text`` are accepted
    but not used by this function.

    Args:
        model: The model to use.
        audio_features: Audio features for the prompt; cast to float32 and
            moved to ``model.device``.
        input_ids: List of prompt id tensors; ``input_ids[0].size(1)`` is T.
        max_returned_tokens: Upper bound on prompt plus generated length.
        temperature, top_k, top_p: Sampling options forwarded to ``sample``.
        eos_id_a: Generation stops when the last audio layer emits this id.
        eos_id_t: Once the text stream emits this id, later text tokens are
            replaced by ``pad_id_t``.
        pad_id_t: Text padding id used after the text stream has ended.
        shift: Offset added to sampled audio tokens before feeding them back.

    Returns:
        A list of eight lists of Python ints (seven audio layers, then text)
        containing only generated tokens. The step that produced the audio
        end-of-sequence id is not recorded.
    """
    T = input_ids[0].size(1)
    device = input_ids[0].device
    assert max_returned_tokens > T
    if model.max_seq_length < max_returned_tokens - 1:
        raise NotImplementedError(
            f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
        )

    sampling = {"temperature": temperature, "top_k": top_k, "top_p": top_p}
    list_output = [[] for _ in range(8)]

    def _record(audio_tokens, text_token):
        # Store the raw (unshifted) ids as plain ints.
        for layer in range(7):
            list_output[layer].append(audio_tokens[layer].tolist()[0])
        list_output[7].append(text_token.tolist()[0])

    def _next_inputs(audio_tokens, text_token):
        # Build the two-row batch that is fed back on the next step.
        batch = []
        for layer in range(7):
            shifted = (
                audio_tokens[layer].clone()
                + shift
                + layer * snac_config.padded_vocab_size
            )
            rows = [
                shifted.clone().to(device).to(torch.int32),
                torch.tensor(
                    [layershift(snac_config.end_of_audio, layer)], device=device
                ),
            ]
            batch.append(torch.stack(rows))
        text_rows = [
            text_token.clone().to(torch.int32),
            text_token.clone().to(torch.int32),
        ]
        batch.append(torch.stack(text_rows))
        return batch

    # Prefill with the full prompt.
    tokens_A, token_T = next_token_image_batch(
        model,
        audio_features.to(torch.float32).to(model.device),
        None,
        input_ids,
        [T - 3, T - 3],
        ["A1T2", "A1T2"],
        input_pos=torch.arange(0, T, device=device),
        **sampling,
    )
    _record(tokens_A, token_T)
    model_input_ids = _next_inputs(tokens_A, token_T)

    input_pos = torch.tensor([T], device=device)
    text_end = False

    for _ in range(2, max_returned_tokens - T + 1):
        tokens_A, token_T = next_token_image_batch(
            model,
            None,
            None,
            model_input_ids,
            None,
            None,
            input_pos=input_pos,
            **sampling,
        )

        if text_end:
            # Text already finished: keep padding while audio continues.
            token_T = torch.tensor([pad_id_t], device=device)

        if tokens_A[-1] == eos_id_a:
            break
        if token_T == eos_id_t:
            text_end = True

        _record(tokens_A, token_T)
        model_input_ids = _next_inputs(tokens_A, token_T)
        input_pos = input_pos.add_(1)

    return list_output
|
|
|
|
@torch.inference_mode()
def generate_TT(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 2048,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> torch.Tensor:
    """Generate text tokens only, without passing audio features.

    The model is called with ``None`` for audio features, ``whisper_lens`` and
    ``task``. After the prompt, each step feeds the end-of-audio id on all
    seven audio layers together with the previously sampled text token.

    ``audio_features``, ``leng``, ``task``, ``eos_id_a``, ``pad_id_t``,
    ``shift``, ``include_prompt`` and ``generate_text`` are accepted but not
    used by this function.

    Args:
        model: The model to use.
        input_ids: List of prompt id tensors; ``input_ids[0].size(1)`` is T.
        max_returned_tokens: Upper bound on prompt plus generated length.
        temperature, top_k, top_p: Sampling options forwarded to ``sample``.
        eos_id_t: Generation stops when this id is sampled inside the loop;
            it is not added to the output.

    Returns:
        A list of generated text token ids as Python ints.
    """
    T = input_ids[0].size(1)
    device = input_ids[0].device
    sampling = {"temperature": temperature, "top_k": top_k, "top_p": top_p}

    # Prefill with the full prompt. The first sampled token is always kept.
    token_T = next_token_A1T1(
        model,
        None,
        input_ids,
        None,
        None,
        input_pos=torch.arange(0, T, device=device),
        **sampling,
    )
    output = [token_T.tolist()[0]]

    input_pos = torch.tensor([T], device=device)
    for _ in tqdm(range(2, max_returned_tokens - T + 1)):
        # Audio layers carry only their end-of-audio id.
        model_input_ids = [
            torch.tensor([layershift(snac_config.end_of_audio, layer)])
            .view(1, -1)
            .to(torch.int32)
            .to(device)
            for layer in range(7)
        ]
        model_input_ids.append(
            token_T.clone().view(1, -1).to(torch.int32).to(device)
        )

        token_T = next_token_A1T1(
            model,
            None,
            model_input_ids,
            None,
            None,
            input_pos=input_pos,
            **sampling,
        )
        if token_T == eos_id_t:
            break
        output.append(token_T.tolist()[0])
        input_pos = input_pos.add_(1)
    return output
|
|
|
|
@torch.inference_mode()
def generate_AT(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 2048,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> torch.Tensor:
    """Generate text tokens from an audio prompt.

    The prompt is fed with the audio features, ``whisper_lens=[T - 3]`` and
    ``task=["AT"]``. After the prompt, each step feeds the end-of-audio id on
    all seven audio layers together with the previously sampled text token.

    ``leng``, ``task``, ``eos_id_a``, ``pad_id_t``, ``shift``,
    ``include_prompt`` and ``generate_text`` are accepted but not used by this
    function.

    Args:
        model: The model to use.
        audio_features: Audio features for the prompt; cast to float32 and
            moved to ``model.device``.
        input_ids: List of prompt id tensors; ``input_ids[0].size(1)`` is T.
        max_returned_tokens: Upper bound on prompt plus generated length.
        temperature, top_k, top_p: Sampling options forwarded to ``sample``.
        eos_id_t: Generation stops when this id is sampled inside the loop;
            it is not added to the output.

    Returns:
        A list of generated text token ids as Python ints.
    """
    T = input_ids[0].size(1)
    device = input_ids[0].device
    sampling = {"temperature": temperature, "top_k": top_k, "top_p": top_p}

    # Prefill with the full prompt. The first sampled token is always kept.
    token_T = next_token_A1T1(
        model,
        audio_features.to(torch.float32).to(model.device),
        input_ids,
        [T - 3],
        ["AT"],
        input_pos=torch.arange(0, T, device=device),
        **sampling,
    )
    output = [token_T.tolist()[0]]

    input_pos = torch.tensor([T], device=device)
    for _ in tqdm(range(2, max_returned_tokens - T + 1)):
        # Audio layers carry only their end-of-audio id.
        model_input_ids = [
            torch.tensor([layershift(snac_config.end_of_audio, layer)])
            .view(1, -1)
            .to(torch.int32)
            .to(device)
            for layer in range(7)
        ]
        model_input_ids.append(
            token_T.clone().view(1, -1).to(torch.int32).to(device)
        )

        token_T = next_token_A1T1(
            model,
            None,
            model_input_ids,
            None,
            None,
            input_pos=input_pos,
            **sampling,
        )
        if token_T == eos_id_t:
            break
        output.append(token_T.tolist()[0])
        input_pos = input_pos.add_(1)
    return output
|
|
|
|
@torch.inference_mode()
def generate_TA(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 2048,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> torch.Tensor:
    """Generate audio and text tokens without passing audio features.

    The model is called with ``None`` for audio features, ``whisper_lens`` and
    ``task``. Sampled audio tokens are passed through ``layershift`` before
    being fed back on the next step.

    ``audio_features``, ``leng``, ``task``, ``shift``, ``include_prompt`` and
    ``generate_text`` are accepted but not used by this function.

    Args:
        model: The model to use.
        input_ids: List of prompt id tensors; ``input_ids[0].size(1)`` is T.
        max_returned_tokens: Upper bound on prompt plus generated length.
        temperature, top_k, top_p: Sampling options forwarded to ``sample``.
        eos_id_a: Generation stops when the last audio layer emits this id.
        eos_id_t: Once the text stream emits this id, later text tokens are
            replaced by ``pad_id_t``.
        pad_id_t: Text padding id used after the text stream has ended.

    Returns:
        A list of eight lists of Python ints (seven audio layers, then text)
        containing only generated tokens. The step that produced the audio
        end-of-sequence id is not recorded.
    """
    T = input_ids[0].size(1)
    device = input_ids[0].device
    sampling = {"temperature": temperature, "top_k": top_k, "top_p": top_p}
    output = [[] for _ in range(8)]

    def _record(audio_tokens, text_token):
        # Store the raw sampled ids as plain ints.
        for layer in range(7):
            output[layer].append(audio_tokens[layer].tolist()[0])
        output[7].append(text_token.tolist()[0])

    # Prefill with the full prompt.
    tokens_A, token_T = next_token_A1T2(
        model,
        None,
        input_ids,
        None,
        None,
        input_pos=torch.arange(0, T, device=device),
        **sampling,
    )
    _record(tokens_A, token_T)

    input_pos = torch.tensor([T], device=device)
    text_end = False
    for _ in tqdm(range(2, max_returned_tokens - T + 1)):
        # Feed back the previous step: layer-shifted audio ids plus the text id.
        model_input_ids = [
            layershift(tokens_A[layer].clone(), layer)
            .view(1, -1)
            .to(torch.int32)
            .to(device)
            for layer in range(7)
        ]
        model_input_ids.append(
            token_T.clone().view(1, -1).to(torch.int32).to(device)
        )

        tokens_A, token_T = next_token_A1T2(
            model,
            None,
            model_input_ids,
            None,
            None,
            input_pos=input_pos,
            **sampling,
        )

        if text_end:
            # Text already finished: keep padding while audio continues.
            token_T = torch.tensor([pad_id_t], device=device)

        if tokens_A[-1] == eos_id_a:
            break
        if token_T == eos_id_t:
            text_end = True

        _record(tokens_A, token_T)
        input_pos = input_pos.add_(1)

    return output
|
|
|
|
@torch.inference_mode()
def generate_AA(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 2048,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> torch.Tensor:
    """Generate audio and text tokens from an audio prompt.

    The prompt is fed with the audio features, ``whisper_lens=[T - 3]`` and
    ``task=["A1T2"]``. Sampled audio tokens are passed through ``layershift``
    before being fed back on the next step.

    ``leng``, ``task``, ``shift``, ``include_prompt`` and ``generate_text``
    are accepted but not used by this function.

    Args:
        model: The model to use.
        audio_features: Audio features for the prompt; cast to float32 and
            moved to ``model.device``.
        input_ids: List of prompt id tensors; ``input_ids[0].size(1)`` is T.
        max_returned_tokens: Upper bound on prompt plus generated length.
        temperature, top_k, top_p: Sampling options forwarded to ``sample``.
        eos_id_a: Generation stops when the last audio layer emits this id.
        eos_id_t: Once the text stream emits this id, later text tokens are
            replaced by ``pad_id_t``.
        pad_id_t: Text padding id used after the text stream has ended.

    Returns:
        A list of eight lists of Python ints (seven audio layers, then text)
        containing only generated tokens. The step that produced the audio
        end-of-sequence id is not recorded.
    """
    T = input_ids[0].size(1)
    device = input_ids[0].device
    sampling = {"temperature": temperature, "top_k": top_k, "top_p": top_p}
    output = [[] for _ in range(8)]

    def _record(audio_tokens, text_token):
        # Store the raw sampled ids as plain ints.
        for layer in range(7):
            output[layer].append(audio_tokens[layer].tolist()[0])
        output[7].append(text_token.tolist()[0])

    # Prefill with the full prompt and its audio features.
    tokens_A, token_T = next_token_A1T2(
        model,
        audio_features.to(torch.float32).to(model.device),
        input_ids,
        [T - 3],
        ["A1T2"],
        input_pos=torch.arange(0, T, device=device),
        **sampling,
    )
    _record(tokens_A, token_T)

    input_pos = torch.tensor([T], device=device)
    text_end = False
    for _ in tqdm(range(2, max_returned_tokens - T + 1)):
        # Feed back the previous step: layer-shifted audio ids plus the text id.
        model_input_ids = [
            layershift(tokens_A[layer].clone(), layer)
            .view(1, -1)
            .to(torch.int32)
            .to(device)
            for layer in range(7)
        ]
        model_input_ids.append(
            token_T.clone().view(1, -1).to(torch.int32).to(device)
        )

        tokens_A, token_T = next_token_A1T2(
            model,
            None,
            model_input_ids,
            None,
            None,
            input_pos=input_pos,
            **sampling,
        )

        if text_end:
            # Text already finished: keep padding while audio continues.
            token_T = torch.tensor([pad_id_t], device=device)

        if tokens_A[-1] == eos_id_a:
            break
        if token_T == eos_id_t:
            text_end = True

        _record(tokens_A, token_T)
        input_pos = input_pos.add_(1)

    return output
|
|
|
|
@torch.inference_mode()
def generate_ASR(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 1200,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> torch.Tensor:
    """Generate text tokens for an audio prompt using the ``"asr"`` task.

    The prompt is fed with the audio features, ``whisper_lens=[T - 3]`` and
    ``task=["asr"]``. After the prompt, each step feeds the end-of-audio id on
    all seven audio layers together with the previously sampled text token.

    ``leng``, ``task``, ``eos_id_a``, ``pad_id_t``, ``shift``,
    ``include_prompt`` and ``generate_text`` are accepted but not used by this
    function.

    Args:
        model: The model to use.
        audio_features: Audio features for the prompt; cast to float32 and
            moved to ``model.device``.
        input_ids: List of prompt id tensors; ``input_ids[0].size(1)`` is T.
        max_returned_tokens: Upper bound on prompt plus generated length.
        temperature, top_k, top_p: Sampling options forwarded to ``sample``.
        eos_id_t: Generation stops when this id is sampled inside the loop;
            it is not added to the output.

    Returns:
        A list of generated text token ids as Python ints.
    """
    T = input_ids[0].size(1)
    device = input_ids[0].device
    sampling = {"temperature": temperature, "top_k": top_k, "top_p": top_p}

    # Prefill with the full prompt. The first sampled token is always kept.
    token_T = next_token_A1T1(
        model,
        audio_features.to(torch.float32).to(model.device),
        input_ids,
        [T - 3],
        ["asr"],
        input_pos=torch.arange(0, T, device=device),
        **sampling,
    )
    output = [token_T.tolist()[0]]

    input_pos = torch.tensor([T], device=device)
    for _ in tqdm(range(2, max_returned_tokens - T + 1)):
        # Audio layers carry only their end-of-audio id.
        model_input_ids = [
            torch.tensor([layershift(snac_config.end_of_audio, layer)])
            .view(1, -1)
            .to(torch.int32)
            .to(device)
            for layer in range(7)
        ]
        model_input_ids.append(
            token_T.clone().view(1, -1).to(torch.int32).to(device)
        )

        token_T = next_token_A1T1(
            model,
            None,
            model_input_ids,
            None,
            None,
            input_pos=input_pos,
            **sampling,
        )
        if token_T == eos_id_t:
            break
        output.append(token_T.tolist()[0])
        input_pos = input_pos.add_(1)
    return output
|
|