|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" Modeling classes for MossTTSDelay. """ |
|
|
|
|
|
from dataclasses import dataclass |
|
|
from typing import List, Optional, Tuple, Union |
|
|
from tqdm import tqdm |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.nn import CrossEntropyLoss |
|
|
|
|
|
from transformers.modeling_utils import PreTrainedModel |
|
|
from transformers.modeling_outputs import ModelOutput |
|
|
from transformers.utils import ( |
|
|
add_start_docstrings, |
|
|
add_start_docstrings_to_model_forward, |
|
|
logging, |
|
|
replace_return_docstrings, |
|
|
) |
|
|
from transformers.cache_utils import Cache |
|
|
from transformers.models.qwen3 import Qwen3Model |
|
|
from transformers import initialization as init |
|
|
|
|
|
from .configuration_moss_tts import MossTTSDelayConfig |
|
|
from .inference_utils import sample_token, find_last_equal_C |
|
|
|
|
|
# The processor module is treated as optional at import time: if importing it
# fails for any reason, the three names fall back to None so this module can
# still be imported. They are re-exported below as class attributes of
# MossTTSDelayModel.
try:
    from .processing_moss_tts import UserMessage, AssistantMessage, MossTTSDelayProcessor
except Exception:
    UserMessage = None
    AssistantMessage = None
    MossTTSDelayProcessor = None

# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)

# Config class name handed to `replace_return_docstrings` on `forward`.
_CONFIG_FOR_DOC = "MossTTSDelayConfig"
|
|
|
|
|
|
|
|
@dataclass
class MossTTSDelayOutputWithPast(ModelOutput):
    """
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Weighted sum of channel losses when `channelwise_loss_weight` is given; otherwise the summed loss of
            all channels divided by the total number of non-masked label entries.
        all_sum_losses (`torch.FloatTensor` of shape `(batch_size, n_vq + 1)`, *optional*):
            Sum of losses for each sample and each channel before averaging.
        all_token_nums (`torch.LongTensor` of shape `(batch_size, n_vq + 1)`, *optional*):
            Number of non-masked (label != -100) tokens per sample and per channel.
        sample_losses (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
            Loss per sample. Only populated when `channelwise_loss_weight` is passed to `forward`.
        channel_losses (`torch.FloatTensor` of shape `(n_vq + 1,)`, *optional*):
            Loss per channel (text head + vq heads).
        logits (`List[torch.FloatTensor]`, *optional*):
            List of prediction scores from each head.
        past_key_values (`Cache`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
            Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer).
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed):
            Tuple of torch.FloatTensor (one for each layer) of the attention weights.
    """

    loss: Optional[torch.FloatTensor] = None
    all_sum_losses: Optional[torch.FloatTensor] = None
    all_token_nums: Optional[torch.LongTensor] = None
    sample_losses: Optional[torch.FloatTensor] = None
    channel_losses: Optional[torch.FloatTensor] = None
    logits: Optional[List[torch.FloatTensor]] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MossTTSDelayPreTrainedModel(PreTrainedModel):
    # Shared PreTrainedModel configuration for the MossTTSDelay models:
    # declares the config class, attention-backend support flags and the
    # custom weight-initialisation hook.
    config_class = MossTTSDelayConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen3DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True

    def _init_weights(self, module):
        """
        Transformers 5.0+ safe init:
        - MUST use transformers.initialization helpers
        - MUST respect param._is_hf_initialized to avoid overwriting ckpt-loaded params

        The parent initialiser runs first. Afterwards only embedding tables
        with exactly ``audio_vocab_size + 1`` rows are re-drawn from a normal
        distribution; every other module keeps what the parent produced.
        """
        super()._init_weights(module)

        # Resolve the std: top-level config first, then the nested language
        # config, finally the 0.02 fallback.
        cfg = self.config
        if hasattr(cfg, "initializer_range"):
            init_std = cfg.initializer_range
        elif hasattr(getattr(cfg, "language_config", None), "initializer_range"):
            init_std = cfg.language_config.initializer_range
        else:
            init_std = 0.02

        is_audio_table = (
            isinstance(module, nn.Embedding)
            and getattr(module, "num_embeddings", None) == cfg.audio_vocab_size + 1
        )
        if is_audio_table:
            init.normal_(module.weight, mean=0.0, std=init_std)
|
|
|
|
|
|
|
|
|
|
|
# Shared docstring preamble. It is prepended to the MossTTSDelayModel class
# docstring via `add_start_docstrings` and to `forward` via
# `add_start_docstrings_to_model_forward`.
MOSSTTS_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`MossTTSDelayConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
|
|
|
|
|
|
|
|
@add_start_docstrings(
    "The MossTTSDelay Model architecture tailored for Text-to-Speech generation with multi-head VQ prediction.",
    MOSSTTS_START_DOCSTRING,
)
class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
    # Convenience re-exports of the (optionally imported) processor types;
    # each is None when the processor module failed to import.
    UserMessage = UserMessage
    AssistantMessage = AssistantMessage
    Processor = MossTTSDelayProcessor

    def __init__(self, config: MossTTSDelayConfig):
        """Build the Qwen3 backbone, the per-channel audio embeddings and the output heads.

        Args:
            config: Model configuration. `config.language_config` configures the
                backbone; `config.n_vq` and `config.audio_vocab_size` size the
                audio embeddings and heads.
        """
        super().__init__(config)
        self.config = config

        # Propagate the outer dtype setting to the backbone config before building it.
        config.language_config.torch_dtype = config.torch_dtype

        self.language_model = Qwen3Model(config.language_config)

        hidden_size = config.language_config.hidden_size
        audio_classes = self.config.audio_vocab_size + 1  # one extra slot beyond the codebook

        # One embedding table per VQ channel; their outputs are summed with the text embedding.
        self.emb_ext = nn.ModuleList()
        for _ in range(self.config.n_vq):
            self.emb_ext.append(nn.Embedding(audio_classes, hidden_size, padding_idx=None))

        # Head 0 predicts the text channel, heads 1..n_vq predict the VQ channels.
        self.lm_heads = nn.ModuleList([
            nn.Linear(hidden_size, config.language_config.vocab_size, bias=False)
        ])
        for _ in range(self.config.n_vq):
            self.lm_heads.append(nn.Linear(hidden_size, audio_classes, bias=False))

        self.post_init()

    def get_input_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """
        Computes the combined embeddings from text and multiple audio VQ channels.

        NOTE(review): unlike `PreTrainedModel.get_input_embeddings()`, this takes
        `input_ids` and returns embedded activations rather than a module.

        Args:
            input_ids: Shape (Batch, Seq_Len, 1 + n_vq)

        Returns:
            The text-channel embedding plus the embedding of every VQ channel,
            summed element-wise.
        """
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids[..., 0])

        # Channel i + 1 of `input_ids` belongs to the i-th audio embedding table.
        for i, embed_layer in enumerate(self.emb_ext):
            inputs_embeds = inputs_embeds + embed_layer(input_ids[..., i + 1])

        return inputs_embeds

    def set_input_embeddings(self, value):
        """Replace the backbone's text token embedding module with `value`."""
        self.language_model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the `nn.ModuleList` of output heads (text head first, then one per VQ channel)."""
        return self.lm_heads

    @add_start_docstrings_to_model_forward(MOSSTTS_START_DOCSTRING)
    @replace_return_docstrings(output_type=MossTTSDelayOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        hidden_out_layers: Optional[List[int]] = None,
        channelwise_loss_weight: Optional[List[float]] = None,
        **kwargs,
    ) -> Union[Tuple, MossTTSDelayOutputWithPast]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, 1 + n_vq)`):
                Indices of input sequence tokens in the vocabulary.
                Dimension 2 contains: [Text/Semantics, VQ_0, VQ_1, ..., VQ_N].
                May be omitted when `inputs_embeds` is provided.
            inputs_embeds (`torch.FloatTensor`, *optional*):
                Pre-computed input embeddings. When given they are fed to the backbone as-is and `input_ids` is
                only shape-checked (if present).
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length, 1 + n_vq)`, *optional*):
                Labels for computing the masked language modeling loss. Entries equal to `-100` are ignored.
                Labels are compared position-for-position with the logits (no shifting is done here).
            hidden_out_layers (`List[int]`, *optional*):
                Indices into the backbone's `hidden_states` tuple, one per head in head order (text head first).
                When omitted every head reads the last hidden state.
            channelwise_loss_weight (`List[float]`, *optional*):
                Manual weights for summing losses across different heads (Text vs Audio channels).
                Must contain one weight per head.

        Returns:
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You must specify either `input_ids` or `inputs_embeds`.")
        # Only validate `input_ids` when it was actually supplied; the embeds-only path is legal.
        if input_ids is not None and (input_ids.dim() != 3 or input_ids.shape[-1] != self.config.n_vq + 1):
            raise ValueError("`Input_ids`'s shape should be exactly (batch_size, sequence_length, 1 + n_vq).")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings(input_ids)

        # The backbone call below pins these two flags; drop caller-supplied
        # duplicates so they cannot raise "multiple values for keyword argument".
        kwargs.pop("output_hidden_states", None)
        kwargs.pop("return_dict", None)

        outputs = self.language_model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        if hidden_out_layers is None:
            hidden_states_for_heads = [outputs.last_hidden_state] * len(self.lm_heads)
        else:
            all_hs = outputs.hidden_states
            hidden_states_for_heads = [all_hs[idx] for idx in hidden_out_layers]

        layer_logits = []
        for i, (hs, head) in enumerate(zip(hidden_states_for_heads, self.lm_heads)):
            logits = head(hs)
            if i > 0:
                # Audio heads: the last class (the extra `+ 1` slot) is never predicted.
                logits[..., -1] = float("-inf")
            layer_logits.append(logits)

        loss = None
        all_sum_losses = None
        all_token_nums = None
        sample_losses = None
        channel_losses = None

        if labels is not None:
            if labels.dim() != 3:
                raise ValueError(f"Labels must have rank 3 (B, S, C), got {labels.shape}")

            batch_size = labels.size(0)
            n_heads = len(layer_logits)

            if channelwise_loss_weight is not None and len(channelwise_loss_weight) != n_heads:
                raise ValueError(f"channelwise_loss_weight length {len(channelwise_loss_weight)} != {n_heads}")

            # (B, C): number of supervised positions per sample and channel.
            all_token_nums = torch.sum(labels != -100, dim=1)

            # reduction='none' keeps per-token losses so they can be summed per sample;
            # ignored positions (-100) contribute 0.
            loss_fct = CrossEntropyLoss(reduction="none")
            all_sum_losses_list = []
            for i, logits in enumerate(layer_logits):
                vocab_size = logits.size(-1)
                per_token_loss = loss_fct(
                    logits.view(-1, vocab_size),
                    labels[..., i].contiguous().view(-1),
                )
                all_sum_losses_list.append(torch.sum(per_token_loss.view(batch_size, -1), dim=-1))

            # (B, C): summed loss per sample and channel.
            all_sum_losses = torch.stack(all_sum_losses_list, dim=1)

            # Per-channel mean over the whole batch; clamp avoids 0/0 for empty channels.
            tokens_per_channel = all_token_nums.sum(dim=0).float().clamp(min=1.0)
            channel_losses = all_sum_losses.sum(dim=0) / tokens_per_channel

            if channelwise_loss_weight is not None:
                w_tensor = torch.tensor(
                    channelwise_loss_weight, device=all_sum_losses.device, dtype=all_sum_losses.dtype
                )

                # Per-sample loss: normalise each (sample, channel) cell by its own token
                # count, then take the weighted average over channels.
                token_counts_safe = all_token_nums.float().clamp(min=1.0)
                normalized_losses = all_sum_losses / token_counts_safe
                sample_losses = (normalized_losses * w_tensor).sum(dim=1) / w_tensor.sum()

                loss = (channel_losses * w_tensor).sum() / w_tensor.sum()
            else:
                # Unweighted: plain mean over every supervised entry of every channel.
                total_tokens = all_token_nums.sum().float().clamp(min=1.0)
                loss = all_sum_losses.sum() / total_tokens

        return MossTTSDelayOutputWithPast(
            loss=loss,
            all_sum_losses=all_sum_losses,
            all_token_nums=all_token_nums,
            sample_losses=sample_losses,
            channel_losses=channel_losses,
            logits=layer_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @torch.inference_mode()
    def generate(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        max_new_tokens: int = 1000,
        text_temperature: float = 1.5,
        text_top_p: float = 1.0,
        text_top_k: int = 50,
        audio_temperature: float = 1.7,
        audio_top_p: float = 0.8,
        audio_top_k: int = 25,
        audio_repetition_penalty: float = 1.0,
    ):
        """Autoregressively generate text and delayed multi-channel audio tokens.

        Args:
            input_ids: Prompt of shape (batch_size, seq_len, 1 + n_vq).
            attention_mask: Mask of shape (batch_size, seq_len). When omitted, an
                all-ones mask is used (i.e. the prompt is assumed to be unpadded).
            max_new_tokens: Upper bound on the number of decoding steps.
            text_temperature: Temperature for the text head; a value <= 0 switches
                the text channel to non-sampling decoding.
            text_top_p: Top-p passed to `sample_token` for the text channel.
            text_top_k: Top-k passed to `sample_token` for the text channel.
            audio_temperature: Temperature for the audio heads; a value <= 0
                switches the audio channels to non-sampling decoding.
            audio_top_p: Top-p passed to `sample_token` for the audio channels.
            audio_top_k: Top-k passed to `sample_token` for the audio channels.
            audio_repetition_penalty: Repetition penalty passed to `sample_token`
                for the audio channels.

        Returns:
            A list with one `(start_length, ids)` tuple per batch element, where
            `ids` is that element's prompt-plus-generated sequence sliced from
            `start_idx = (index returned by find_last_equal_C for im_start_token_id) + 3`
            and `start_length = seq_len - start_idx`.
        """
        if text_temperature > 0:
            text_do_sample = True
        else:
            text_temperature = 1
            text_do_sample = False
        if audio_temperature > 0:
            audio_do_sample = True
        else:
            audio_temperature = 1
            audio_do_sample = False

        past_key_values = None
        device = input_ids.device
        batch_size, seq_len, n_channels = input_ids.shape
        n_vq = n_channels - 1

        # The mask is extended by one column per step below, so it must be a tensor.
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long, device=device)

        current_input_ids = input_ids
        current_attention_mask = attention_mask

        generation_ids = input_ids
        is_stopping = torch.zeros(batch_size, dtype=torch.bool, device=device)

        # audio_lengths: steps since the current audio segment started (0 = not in audio).
        # delayed_lengths: steps since the delay phase started; int64 max = not in delay phase.
        audio_lengths = torch.zeros(batch_size, dtype=torch.int64, device=device)
        torch_int64_max = torch.iinfo(torch.int64).max
        delayed_lengths = torch.full((batch_size,), torch_int64_max, dtype=torch.int64, device=device)

        # A prompt ending in an audio-start or gen-slot token continues an open audio segment.
        last_text_token = input_ids[:, -1, 0]
        is_continuation = (last_text_token == self.config.audio_start_token_id) | (
            last_text_token == self.config.audio_assistant_gen_slot_token_id
        )
        # presumably the per-row index of the last occurrence, -1 when absent — verify against helper
        audio_start_indices = find_last_equal_C(input_ids[..., 0], self.config.audio_start_token_id)
        audio_start_mask = is_continuation & (audio_start_indices != -1)
        audio_lengths[audio_start_mask] = seq_len - audio_start_indices[audio_start_mask]

        is_audio = audio_start_mask.clone()

        # Outside audio: these text tokens are banned. Inside audio: everything except the
        # gen-slot and delay-slot tokens is banned.
        pre_exclude_mask0 = torch.tensor(
            [
                self.config.pad_token_id,
                self.config.audio_assistant_gen_slot_token_id,
                self.config.audio_assistant_delay_slot_token_id,
                self.config.audio_end_token_id,
            ],
            device=device,
        )
        pre_exclude_mask1 = torch.ones(self.config.language_config.vocab_size, device=device).bool()
        pre_exclude_mask1[
            [self.config.audio_assistant_gen_slot_token_id, self.config.audio_assistant_delay_slot_token_id]
        ] = False

        # Loop-invariant channel index grid, shape (batch_size, n_vq).
        channel_index = torch.arange(n_vq, dtype=torch.int64, device=device).expand(batch_size, n_vq)

        for time_step in tqdm(range(max_new_tokens), desc=f"Generating bs{batch_size} ..."):
            outputs = self(
                input_ids=current_input_ids,
                attention_mask=current_attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values

            next_token_logits = [
                logit[:, -1, :] / text_temperature if logit_idx == 0 else logit[:, -1, :] / audio_temperature
                for logit_idx, logit in enumerate(outputs.logits)
            ]
            next_token_logits[0] = next_token_logits[0].clone()

            # ---- text channel -------------------------------------------------
            # Forced tokens: pad once stopped, delay-slot while flushing the delayed
            # channels, audio-end when the flush completes; otherwise sample.
            next_text_token = torch.full((batch_size,), self.config.pad_token_id, device=device)
            next_text_token[~is_stopping & (delayed_lengths < n_vq)] = self.config.audio_assistant_delay_slot_token_id
            is_audio_eos = ~is_stopping & (delayed_lengths == n_vq)
            next_text_token[is_audio_eos] = self.config.audio_end_token_id
            is_audio[is_audio_eos] = False
            sampling_text_mask = ~is_stopping & (delayed_lengths > n_vq)
            next_token_logits[0][~is_audio] = next_token_logits[0][~is_audio].index_fill(
                -1, pre_exclude_mask0, float('-inf')
            )
            next_token_logits[0][is_audio] = next_token_logits[0][is_audio].masked_fill(
                pre_exclude_mask1, float('-inf')
            )
            if time_step == 0:
                # NOTE(review): hard-coded token id banned on the first step; its meaning is
                # not visible here — confirm and consider moving it into the config.
                next_token_logits[0][..., 151662] = float('-inf')
            if time_step <= n_vq:
                # Do not allow end-of-message during the first n_vq + 1 steps.
                next_token_logits[0][..., self.config.im_end_token_id] = float('-inf')

            next_text_token[sampling_text_mask] = sample_token(
                logits=next_token_logits[0][sampling_text_mask],
                top_p=text_top_p,
                top_k=text_top_k,
                do_sample=text_do_sample
            )
            is_audio[next_text_token == self.config.audio_start_token_id] = True
            is_stopping[next_text_token == self.config.im_end_token_id] = True

            # ---- audio channels -----------------------------------------------
            # Channel k is sampled only once the segment is more than k steps old
            # (delay pattern) and, during the flush, only if it has not finished yet.
            # Everything else keeps the pad code it was initialised with.
            next_audio_tokens = torch.full((batch_size, n_vq), self.config.audio_pad_code, device=device)

            pre_audio_mask = audio_lengths.unsqueeze(1) > channel_index
            post_audio_mask = channel_index > delayed_lengths.unsqueeze(1) - 1
            post_audio_mask[delayed_lengths == torch_int64_max] = True
            sampling_audio_mask = pre_audio_mask & post_audio_mask

            if sampling_audio_mask.any():
                audio_ch0_logits = next_token_logits[1][sampling_audio_mask[:, 0]]
                audio_logits = torch.stack(next_token_logits[2:], dim=1)[sampling_audio_mask[:, 1:]]
                audio_ch0_logits[..., self.config.audio_pad_code] = float('-inf')
                audio_logits[..., self.config.audio_pad_code] = float('-inf')
                next_audio_tokens[:, 0][sampling_audio_mask[:, 0]] = sample_token(
                    logits=audio_ch0_logits,
                    prev_tokens=generation_ids[:, :, 1],
                    repetition_penalty=audio_repetition_penalty,
                    top_p=audio_top_p,
                    top_k=audio_top_k,
                    do_sample=audio_do_sample
                )
                next_audio_tokens[:, 1:][sampling_audio_mask[:, 1:]] = sample_token(
                    logits=audio_logits,
                    prev_tokens=generation_ids[:, :, 2:],
                    repetition_penalty=audio_repetition_penalty,
                    top_p=audio_top_p,
                    top_k=audio_top_k,
                    do_sample=audio_do_sample
                )

            # ---- state bookkeeping (order matters) -----------------------------
            audio_lengths[
                (next_text_token == self.config.audio_start_token_id)
                | (next_text_token == self.config.audio_assistant_gen_slot_token_id)
                | (next_text_token == self.config.audio_assistant_delay_slot_token_id)
            ] += 1
            audio_lengths[next_text_token == self.config.audio_end_token_id] = 0
            delayed_lengths[
                (delayed_lengths == torch_int64_max)
                & (next_text_token == self.config.audio_assistant_delay_slot_token_id)
            ] = 0
            delayed_lengths[delayed_lengths != torch_int64_max] += 1
            delayed_lengths[delayed_lengths > n_vq] = torch_int64_max

            current_input_ids = torch.cat([next_text_token[:, None, None], next_audio_tokens[:, None, :]], dim=2)
            current_attention_mask = torch.cat([current_attention_mask, (~is_stopping).unsqueeze(-1)], dim=-1)
            generation_ids = torch.cat([generation_ids, current_input_ids], dim=1)

            if is_stopping.all():
                break

        start_indices = find_last_equal_C(input_ids[..., 0], self.config.im_start_token_id) + 3
        start_lengths = seq_len - start_indices

        return [
            (start_length, cur_generation_ids[start_idx:])
            for start_idx, start_length, cur_generation_ids in zip(start_indices, start_lengths, generation_ids)
        ]
|
|
|