import re
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

import torch
from transformers import PreTrainedTokenizerBase, StoppingCriteria

Prompt = List[Union[str, List[int], List[str]]]
Word = Union[str, List[int]]
Context = Word
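
# Illustrative values (not part of this module) for the aliases above:
#
#     prompt: Prompt = ['<s>', [8948], '{{QUERY}}']  # mixed text pieces and token-id lists
#     stop: Word = '</s>'                            # a stop word as text, or e.g. [2] as token ids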


class ContextType:
    RESPONSE = 'response'
    SUFFIX = 'suffix'
    OTHER = 'other'


class StopWordsCriteria(StoppingCriteria):
    """Add extra stop words from the template to prevent runaway generation,
    e.g. the suffixes and chat separators defined in the template.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_words: List[Word], **tokenizer_kwargs) -> None:
        self.tokenizer = tokenizer
        self.stop_words = stop_words
        self.tokenizer_kwargs = tokenizer_kwargs
        self.start_idx = -1
        self.is_done = None

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, **kwargs) -> torch.Tensor:
        if self.start_idx == -1:
            # First call: record the prompt length and initialize the per-sample done mask.
            self.start_idx = len(input_ids[0]) - 1
            self.is_done = torch.full((input_ids.shape[0], ), False, device=input_ids.device, dtype=torch.bool)
        # Only decode the tail of the sequence (at most the last 20 tokens) to keep the check cheap.
        start_idx = max(self.start_idx, input_ids.shape[1] - 20)
        text_list = self.tokenizer.batch_decode(input_ids[:, start_idx:], **self.tokenizer_kwargs)
        for i, text in enumerate(text_list):
            if self.is_done[i]:
                continue
            is_finished = False
            for stop_word in self.stop_words:
                # A str stop word is matched against the decoded text; a token-id list is
                # matched against the raw tail of input_ids.
                if (isinstance(stop_word, str) and stop_word in text
                        or isinstance(stop_word, list) and input_ids[i][-len(stop_word):].tolist() == stop_word):
                    is_finished = True
                    break
            self.is_done[i] = is_finished
        return self.is_done
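
# A minimal usage sketch (illustrative, not part of this module): plug the
# criteria into `model.generate` via transformers' StoppingCriteriaList. The
# `model`/`tokenizer` objects and the stop words are assumed for the example.
#
#     from transformers import StoppingCriteriaList
#
#     criteria = StopWordsCriteria(tokenizer, ['Observation:', [2]], skip_special_tokens=False)
#     output = model.generate(input_ids, stopping_criteria=StoppingCriteriaList([criteria]))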


def fetch_one(element: Union[Tuple, List, Set, Dict, Any], item_type: Optional[Type] = None) -> Any:
    """Recursively fetch the first truthy leaf element, optionally restricted to `item_type`."""
    if isinstance(element, (tuple, set, list)):
        for ele in element:
            # Pass item_type through so the type filter also applies inside nested containers.
            out = fetch_one(ele, item_type)
            if out and (item_type is None or isinstance(out, item_type)):
                return out
    elif isinstance(element, dict):
        return fetch_one(list(element.values()), item_type)
    else:
        return element
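
# Illustrative examples (not part of this module) of how `fetch_one` walks
# nested containers and returns the first truthy leaf:
#
#     fetch_one([None, {'image': 'x.png'}])    # -> 'x.png'
#     fetch_one(((0, 'a.png'), 'b.png'), str)  # -> 'a.png'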


def findall(token_list: List[int], sub_token_list: Union[int, List[int]]) -> List[int]:
    """Find the start index of every occurrence of sub_token_list in token_list."""
    if isinstance(sub_token_list, int):
        sub_token_list = [sub_token_list]
    res = []
    idx = -1
    try:
        while True:
            # Locate the first token of the sub-sequence, then verify the full match.
            idx = token_list.index(sub_token_list[0], idx + 1)
            if len(sub_token_list) == 1 or sub_token_list == token_list[idx:idx + len(sub_token_list)]:
                res.append(idx)
    except ValueError:
        pass
    return res
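
# Illustrative examples (not part of this module): `findall` returns the start
# index of every (possibly overlapping) occurrence of the sub-sequence:
#
#     findall([1, 2, 3, 2, 3], [2, 3])  # -> [1, 3]
#     findall([1, 2, 3], 3)             # -> [2]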


def align_image_inputs(input_ids: List[int], labels: List[int], new_input_ids: Union[List[int], torch.Tensor],
                       image_token: int) -> Tuple[List[int], List[int]]:
    """Replace each image placeholder token in input_ids with the corresponding expanded
    span from new_input_ids, masking the inserted positions in labels with -100."""
    if isinstance(new_input_ids, torch.Tensor):
        new_input_ids = new_input_ids.tolist()

    # i walks input_ids, j walks new_input_ids; the two sequences are assumed to be
    # identical except where the image placeholder was expanded.
    i, j = 0, 0
    while i < len(input_ids):
        x = input_ids[i]
        if x == image_token:
            assert i + 1 < len(input_ids), f'input_ids[-10:]: {input_ids[-10:]}'
            assert i - 1 >= 0, f'input_ids[:10]: {input_ids[:10]}'

            # Re-anchor j on the token just before the placeholder, searching a few
            # positions in both directions to tolerate small tokenization drift.
            j_begin = j - 1
            for k in range(5):
                if j_begin + k < len(new_input_ids) and new_input_ids[j_begin + k] == input_ids[i - 1]:
                    j_begin += k
                    break
                if j_begin - k >= 0 and new_input_ids[j_begin - k] == input_ids[i - 1]:
                    j_begin -= k
                    break
            else:
                raise ValueError(f'new_input_ids: {new_input_ids}, input_ids: {input_ids}')
            j_begin += 1
            # Advance j to the token that follows the expanded image span.
            while j < len(new_input_ids) and new_input_ids[j] != input_ids[i + 1]:
                j += 1
            input_ids = input_ids[:i] + new_input_ids[j_begin:j] + input_ids[i + 1:]
            if labels:
                labels = labels[:i] + [-100] * (j - j_begin) + labels[i + 1:]
            i += j - j_begin
        else:
            j += 1
            i += 1
    return input_ids, labels
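
# Illustrative example (not part of this module): a single image placeholder
# (token 9 here, an assumed value) is replaced by the expanded image span from
# `new_input_ids`, and the inserted label positions are masked with -100:
#
#     align_image_inputs([1, 9, 2], [1, 9, 2], [1, 7, 7, 7, 2], image_token=9)
#     # -> ([1, 7, 7, 7, 2], [1, -100, -100, -100, 2])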


def _split_str_by_regex(text: str, regex_delimiters: List[str]) -> List[str]:
    """Split text by the given regex delimiters, returning an alternating
    [delimiter, content, delimiter, content, ...] list that starts with a
    (possibly empty) delimiter."""
    combined_pattern = '|'.join(f'({pattern})' for pattern in regex_delimiters)
    parts = re.split(combined_pattern, text, flags=re.DOTALL)
    parts = [part for part in parts if part is not None]
    if parts[0] == '':
        parts.pop(0)
    else:
        # Ensure the list starts with a delimiter slot so parts alternate key/content.
        parts.insert(0, '')
    assert len(parts) % 2 == 0, f'result: {parts}'
    assert ''.join(parts) == text, f'split_result: {parts}, text: {text}'
    return parts
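
# Illustrative example (not part of this module): the result alternates
# delimiter/content pairs, starting with the (possibly empty) delimiter that
# precedes the first piece of content:
#
#     _split_str_by_regex('a<b>c', ['<', '>'])  # -> ['', 'a', '<', 'b', '>', 'c']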


def split_str_parts_by(text: str, delimiters: List[str], regex_mode: bool = False) -> List[Dict[str, str]]:
    """Split the text field into parts.

    Args:
        text: The text to be split.
        delimiters: The delimiters.
        regex_mode: If True, treat the delimiters as regex patterns rather than literal strings.

    Returns:
        The split text as a list of {'key': delimiter, 'content': text} dicts.
    """
    assert isinstance(text, str), f'text: {text}'
    delimiters_origin = delimiters
    if not regex_mode:
        # Literal delimiters must be escaped before being used as regex patterns;
        # in regex mode they are already patterns and must be left untouched.
        delimiters = [re.escape(delimiter) for delimiter in delimiters]
    parts = _split_str_by_regex(text, delimiters) if delimiters else ['', text]
    res = []
    if regex_mode:
        parts = [part for part in parts if part]
        for part in parts:
            # Recover which delimiter pattern (if any) produced this part.
            for delimiter, delimiter_origin in zip(delimiters, delimiters_origin):
                if re.match(delimiter, part, re.DOTALL):
                    break
            else:
                delimiter_origin = ''
            res.append({'key': delimiter_origin, 'content': part})
    else:
        for key, content in zip(parts[::2], parts[1::2]):
            res.append({'key': key, 'content': content})
    return res
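
# Illustrative example (not part of this module), e.g. parsing an agent-style
# response into keyed sections:
#
#     split_str_parts_by('Action: foo\nAction Input: bar', ['Action:', 'Action Input:'])
#     # -> [{'key': 'Action:', 'content': ' foo\n'},
#     #     {'key': 'Action Input:', 'content': ' bar'}]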