import torch
from typing import Dict, Tuple, List
from dataclasses import dataclass


@dataclass
class TensorConfig:
    """Configuration for padding and sequence-length limits."""
    pad_token_id: int
    max_prompt_length: int
    max_obs_length: int
    max_start_length: int


class TensorHelper:
    """Helper for padding, masking, and reshaping batched token tensors."""

    def __init__(self, config: TensorConfig):
        self.config = config

    def cut_to_effective_len(self, tensor_dict: Dict[str, torch.Tensor],
                             keys: List[str], cut_left: bool = True) -> Dict[str, torch.Tensor]:
        """Cut tensors to their effective length based on attention mask."""
        effective_len = tensor_dict['attention_mask'].sum(dim=1).max()
        result = tensor_dict.copy()

        for key in keys:
            if cut_left:
                result[key] = tensor_dict[key][:, -effective_len:]
            else:
                result[key] = tensor_dict[key][:, :effective_len]
        return result

    def convert_pad_structure(self, tensor: torch.Tensor, pad_to_left: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert padding structure and return sorted tensor with indices."""
        # Non-pad tokens map to 1 when padding to the left (so pads sort first);
        # pad tokens map to 1 when padding to the right (so pads sort last).
        mask = tensor != self.config.pad_token_id if pad_to_left else tensor == self.config.pad_token_id
        sorted_indices = mask.to(torch.int64).argsort(dim=1, stable=True)
        return tensor.gather(1, sorted_indices), sorted_indices
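
    # Illustrative sketch (values assumed, not from the original code): how one
    # row is reordered by convert_pad_structure with pad_token_id == 0 and
    # pad_to_left=True.
    #   row                 = [5, 7, 0, 0]   # right-padded
    #   mask (non-pad -> 1) = [1, 1, 0, 0]
    #   stable argsort      = [2, 3, 0, 1]   # pads sort first, original order kept
    #   gathered row        = [0, 0, 5, 7]   # now left-padded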

    def create_attention_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Create attention mask from input ids."""
        return torch.where(input_ids != self.config.pad_token_id, 1, 0)

    def create_position_ids(self, attention_mask: torch.Tensor) -> torch.Tensor:
        """Create position ids from attention mask."""
        return (torch.cumsum(attention_mask, dim=1) - 1) * attention_mask
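
    # Illustrative sketch (assumed left-padded mask) of the cumsum trick above:
    #   attention_mask = [0, 0, 1, 1, 1]
    #   cumsum - 1     = [-1, -1, 0, 1, 2]
    #   * mask         = [0, 0, 0, 1, 2]   # positions start at 0 at the first real token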

    def concatenate_with_padding(
        self, tensors: List[torch.Tensor],
        pad_to_left: bool = True
    ) -> torch.Tensor:
        """Concatenate tensors and handle padding."""
        concatenated = torch.cat(tensors, dim=1)
        padded_tensor, _ = self.convert_pad_structure(concatenated, pad_to_left)
        return padded_tensor

    def example_level_pad(
        self, responses: torch.Tensor,
        responses_str: List[str],
        active_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, List[str]]:
        """Pad responses for non-active examples back to the full batch size.

        `responses` holds only the rows for active examples; inactive rows are
        filled with pad tokens (and empty strings for the text responses).
        """
        assert active_mask.sum() == responses.shape[0]

        batch_size = active_mask.shape[0]
        seq_len = responses.shape[1]
        padded_responses = torch.full(
            (batch_size, seq_len), self.config.pad_token_id,
            dtype=responses.dtype, device=responses.device
        )
        padded_responses[active_mask] = responses

        padded_responses_str = [""] * batch_size

        s = 0
        for i, is_active in enumerate(active_mask):
            if is_active:
                padded_responses_str[i] = responses_str[s]
                s += 1

        return padded_responses, padded_responses_str
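
    # Illustrative sketch (assumed values): with active_mask = [True, False, True]
    # and two active rows in `responses`, rows 0 and 2 of the padded output receive
    # those rows while row 1 is all pad tokens and an empty string.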

    def erase_after_first_eos(self, completion_ids: torch.Tensor, eos_token_id: int) -> torch.Tensor:
        """Replace every token after the first EOS with the pad token (in place).

        Rows that contain no EOS token are left unchanged.
        """
        is_eos_mask = (completion_ids == eos_token_id)
        # argmax returns the index of the first True; rows without any EOS yield 0,
        # so they are masked out below via `is_eos_mask.any(dim=1)`.
        first_eos_indices = torch.argmax(is_eos_mask.int(), dim=1)
        seq_len = completion_ids.size(1)
        col_indices = torch.arange(seq_len, device=completion_ids.device)
        mask_to_replace = (col_indices > first_eos_indices.unsqueeze(1)) & is_eos_mask.any(dim=1).unsqueeze(1)
        completion_ids[mask_to_replace] = self.config.pad_token_id

        return completion_ids
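

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): the toy ids and config values below
# are assumptions, not part of the original module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = TensorConfig(pad_token_id=0, max_prompt_length=16,
                          max_obs_length=16, max_start_length=16)
    helper = TensorHelper(config)

    # Two right-padded toy sequences; 0 plays the role of the pad token here.
    input_ids = torch.tensor([[5, 7, 9, 0],
                              [3, 0, 0, 0]])

    attention_mask = helper.create_attention_mask(input_ids)
    position_ids = helper.create_position_ids(attention_mask)
    left_padded, _ = helper.convert_pad_structure(input_ids, pad_to_left=True)

    print(attention_mask)  # tensor([[1, 1, 1, 0], [1, 0, 0, 0]])
    print(position_ids)    # tensor([[0, 1, 2, 0], [0, 0, 0, 0]])
    print(left_padded)     # tensor([[0, 5, 7, 9], [0, 0, 0, 3]])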