|
|
import copy |
|
|
import torch |
|
|
import transformers |
|
|
import tokenizers |
|
|
|
|
|
from typing import Dict, Sequence |
|
|
|
|
|
from ola.constants import IGNORE_INDEX, DEFAULT_SPEECH_TOKEN, IMAGE_TOKEN_INDEX |
|
|
from ola import conversation as conversation_lib |
|
|
from ola.model import * |
|
|
from ola.arguments import DataArguments |
|
|
from ola.constants import SPEECH_TOKEN_INDEX |
|
|
|
|
|
from packaging import version |
|
|
|
|
|
# True when the installed `tokenizers` package is version 0.14 or newer.
# preprocess_v1 combines this flag with `tokenizer.legacy` to decide whether
# each round after the first is one token shorter than its standalone encoding.
IS_TOKENIZER_GREATER_THAN_0_14 = version.Version('0.14') <= version.parse(tokenizers.__version__)
|
|
|
|
|
|
|
|
def tokenizer_speech_token(prompt, tokenizer, speech_token_index=SPEECH_TOKEN_INDEX, return_tensors=None):
    """Tokenize ``prompt``, writing ``speech_token_index`` for every ``<speech>`` tag.

    The text between tags is tokenized piece by piece with ``tokenizer``. When
    the first piece starts with ``tokenizer.bos_token_id``, that id is emitted
    once at the front and the first token of *every* piece is then dropped, so
    the per-piece BOS is not repeated inside the sequence.

    Args:
        prompt: Text that may contain ``<speech>`` placeholders.
        tokenizer: Callable whose result has an ``input_ids`` list; must also
            expose ``bos_token_id``.
        speech_token_index: Id emitted in place of each ``<speech>`` tag.
        return_tensors: ``None`` for a Python list, ``'pt'`` for a
            ``torch.long`` tensor.

    Returns:
        The token ids as a list, or a 1-D ``torch.long`` tensor for ``'pt'``.

    Raises:
        ValueError: If ``return_tensors`` is neither ``None`` nor ``'pt'``.
    """
    encoded_pieces = [tokenizer(piece).input_ids for piece in prompt.split('<speech>')]

    skip = 0
    token_ids = []
    head = encoded_pieces[0] if len(encoded_pieces) > 0 else []
    if len(head) > 0 and head[0] == tokenizer.bos_token_id:
        # Keep a single leading BOS and strip the first token of each piece below.
        skip = 1
        token_ids.append(head[0])

    for position, ids in enumerate(encoded_pieces):
        if position:
            # Every boundary between two consecutive pieces was a '<speech>' tag.
            token_ids.append(speech_token_index)
        token_ids.extend(ids[skip:])

    if return_tensors is None:
        return token_ids
    if return_tensors == 'pt':
        return torch.tensor(token_ids, dtype=torch.long)
    raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
|
|
|
|
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    """Tokenize ``prompt``, writing ``image_token_index`` for every ``<image>`` tag.

    Segments between tags are tokenized independently. If the first segment
    begins with ``tokenizer.bos_token_id``, that BOS is kept exactly once at
    the start and the first token of every segment is trimmed off.

    Args:
        prompt: Text that may contain ``<image>`` placeholders.
        tokenizer: Callable whose result has an ``input_ids`` list; must also
            expose ``bos_token_id``.
        image_token_index: Id emitted in place of each ``<image>`` tag.
        return_tensors: ``None`` for a Python list, ``'pt'`` for a
            ``torch.long`` tensor.

    Returns:
        The token ids as a list, or a 1-D ``torch.long`` tensor for ``'pt'``.

    Raises:
        ValueError: If ``return_tensors`` is neither ``None`` nor ``'pt'``.
    """
    segments = [tokenizer(text).input_ids for text in prompt.split('<image>')]

    starts_with_bos = (
        len(segments) > 0
        and len(segments[0]) > 0
        and segments[0][0] == tokenizer.bos_token_id
    )
    trim = 1 if starts_with_bos else 0
    result = [segments[0][0]] if starts_with_bos else []

    # Head segment first, then one image id before each remaining segment.
    result.extend(segments[0][trim:])
    for segment in segments[1:]:
        result.append(image_token_index)
        result.extend(segment[trim:])

    if return_tensors is None:
        return result
    if return_tensors == 'pt':
        return torch.tensor(result, dtype=torch.long)
    raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
|
|
|
|
def tokenizer_speech_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, speech_token_idx=SPEECH_TOKEN_INDEX, return_tensors=None):
    """Tokenize ``prompt``, replacing every ``<speech><image>`` pair with two placeholder ids.

    Text between occurrences of the literal ``<speech><image>`` is tokenized
    piece by piece. Each occurrence becomes exactly
    ``[speech_token_idx, image_token_index]``. If the first piece starts with
    ``tokenizer.bos_token_id``, that BOS is emitted once at the front and the
    first token of every piece is dropped, so per-piece BOS tokens do not repeat.

    Args:
        prompt: Text that may contain ``<speech><image>`` placeholders.
        tokenizer: Callable whose result has an ``input_ids`` list; must also
            expose ``bos_token_id``.
        image_token_index: Id emitted for the ``<image>`` half of each pair.
        speech_token_idx: Id emitted for the ``<speech>`` half of each pair.
        return_tensors: ``None`` for a Python list, ``'pt'`` for a
            ``torch.long`` tensor.

    Returns:
        The token ids as a list, or a 1-D ``torch.long`` tensor for ``'pt'``.

    Raises:
        ValueError: If ``return_tensors`` is neither ``None`` nor ``'pt'``.
    """
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<speech><image>')]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    # The separator is spliced in verbatim. The previous "repeat (offset + 1)
    # times, then slice [offset:]" trick only works for single-token
    # separators. For this two-token pair it produced [image, speech, image]
    # whenever the tokenizer emitted a BOS.
    placeholder = [speech_token_idx, image_token_index]
    for idx, chunk_ids in enumerate(prompt_chunks):
        if idx > 0:
            input_ids.extend(placeholder)
        input_ids.extend(chunk_ids[offset:])

    if return_tensors is None:
        return input_ids
    if return_tensors == 'pt':
        return torch.tensor(input_ids, dtype=torch.long)
    raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
|
|
|
|
def tokenizer_speech_question_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, speech_token_idx=SPEECH_TOKEN_INDEX, return_tensors=None):
    """Tokenize ``prompt``, expanding each image + spoken-question template into ids.

    Every occurrence of the literal
    ``"<image>\\nUser's question in speech: <speech>\\n"`` is replaced by one
    flat run of ids::

        [image_token_index] + ids("\\n") + ids("User's question in speech: ")
        + [speech_token_idx] + ids("\\n")

    The fragment encodings have a leading BOS removed so they can sit
    mid-sequence. Text between template occurrences is tokenized piece by
    piece. If the first piece starts with ``tokenizer.bos_token_id``, that BOS
    is emitted once at the front and the first token of every piece is dropped.

    Args:
        prompt: Text that may contain the template above.
        tokenizer: Callable whose result has an ``input_ids`` list; must also
            expose ``bos_token_id``.
        image_token_index: Id emitted for the ``<image>`` tag.
        speech_token_idx: Id emitted for the ``<speech>`` tag.
        return_tensors: ``None`` for a Python list, ``'pt'`` for a
            ``torch.long`` tensor.

    Returns:
        A flat list of ids, or a 1-D ``torch.long`` tensor for ``'pt'``.

    Raises:
        ValueError: If ``return_tensors`` is neither ``None`` nor ``'pt'``.
    """
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>\nUser's question in speech: <speech>\n")]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    bos_id = tokenizer.bos_token_id

    def _strip_bos(ids):
        """Drop a leading BOS id so a fragment encoding can be spliced mid-sequence."""
        if len(ids) > 0 and bos_id is not None and ids[0] == bos_id:
            return list(ids[1:])
        return list(ids)

    # Build the separator as a FLAT list of ints. The previous version nested
    # the newline/question token lists (and wrapped the whole thing in another
    # list), so `extend` inserted lists into input_ids. The result was ragged,
    # and torch.tensor could not convert it.
    nl_tokens = _strip_bos(tokenizer("\n").input_ids)
    question_tokens = _strip_bos(tokenizer("User's question in speech: ").input_ids)
    separator = [image_token_index, *nl_tokens, *question_tokens, speech_token_idx, *nl_tokens]

    for idx, chunk_ids in enumerate(prompt_chunks):
        if idx > 0:
            input_ids.extend(separator)
        input_ids.extend(chunk_ids[offset:])

    if return_tensors is None:
        return input_ids
    if return_tensors == 'pt':
        return torch.tensor(input_ids, dtype=torch.long)
    raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
|
|
|
|
def preprocess_v1(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_speech: bool = False
) -> Dict:
    """Tokenize v1-template conversations and build labels that only train on replies.

    Each source is rendered with a copy of
    ``conversation_lib.default_conversation`` and then tokenized. When
    ``has_speech`` is set, ``tokenizer_speech_token`` is used so ``<speech>``
    becomes the speech placeholder id; otherwise the HF tokenizer is used with
    longest-padding and truncation to ``model_max_length``.

    Labels are a copy of ``input_ids`` with the following set to
    ``IGNORE_INDEX``:
    - position 0;
    - the instruction part of every round (the text up to and including
      ``sep + roles[1] + ": "``);
    - everything after the last counted round.

    If the reconstructed length disagrees with the number of non-pad tokens,
    the whole sample is masked and a warning is printed.

    Args:
        sources: Iterable of conversations. Each conversation is a list of dicts
            with ``"from"`` (``"human"`` or ``"gpt"``) and ``"value"`` keys.
            A leading non-human turn is dropped, and turns must then alternate
            human/gpt (enforced with ``assert``).
        tokenizer: Tokenizer providing ``pad_token_id``, ``model_max_length``
            and, for multi-round conversations, ``legacy``.
        has_speech: Whether prompts contain ``<speech>`` placeholders.

    Returns:
        ``dict(input_ids=..., labels=...)``: two tensors of identical shape
        ``(len(sources), seq_len)``.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Render every source with the conversation template.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Drop a leading turn that does not come from the human side.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize the rendered prompts.
    if has_speech:
        # NOTE(review): torch.stack needs every prompt to encode to the same
        # length, so this path effectively assumes one conversation per call
        # (or equal lengths). Unlike the text path, it is neither padded nor
        # truncated to model_max_length.
        input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO

    # Mask everything except the assistant replies. Iterating a 2-D tensor
    # yields row views, so the in-place writes below land in `targets`.
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        # Position 0 is always masked — NOTE(review): this assumes a BOS token
        # at index 0; confirm for tokenizers that do not prepend one.
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            # The "- 2" follows the upstream LLaVA recipe (presumably BOS plus
            # a boundary token of the separator) — verify for the tokenizer in use.
            if has_speech:
                round_len = len(tokenizer_speech_token(rou, tokenizer))
                instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            # Non-legacy tokenizers on tokenizers>=0.14 encode later rounds one
            # token shorter in context than standalone. NOTE(review):
            # `tokenizer.legacy` raises AttributeError on tokenizers lacking it.
            if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
                round_len -= 1
                instruction_len -= 1

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                # The masks cannot be trusted: drop this sample from the loss.
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )
                # Show the start of the offending conversation for diagnosis.
                # (The former "all labels ignored" print always reported True
                # right after the line above, and has been removed.)
                print(f"Debug - Conversation: {conversation[:200]}...")

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
|
|
|
|
|
|
|
|
def preprocess_plain(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Build speech-pretraining samples of the form ``<speech> + answer + sep``.

    For every source (exactly two turns, the first containing
    ``DEFAULT_SPEECH_TOKEN``), the first turn's ``'value'`` is overwritten in
    place with ``DEFAULT_SPEECH_TOKEN``. This mutates the caller's data. The
    prompt is that placeholder followed by the second turn's text and the
    default conversation separator. Prompts are encoded with
    ``tokenizer_speech_token``.

    Returns:
        ``{'input_ids': [...], 'labels': [...]}``: lists of 1-D ``torch.long``
        tensors. In each label tensor, the tokens of the (rewritten) first turn
        are set to ``IGNORE_INDEX``.
    """
    prompts = []
    for turn_pair in sources:
        assert len(turn_pair) == 2
        human_turn = turn_pair[0]
        reply_text = turn_pair[1]['value']
        assert DEFAULT_SPEECH_TOKEN in human_turn['value']
        # Only the speech placeholder is kept from the human side.
        human_turn['value'] = DEFAULT_SPEECH_TOKEN
        separator = conversation_lib.default_conversation.sep
        prompts.append(human_turn['value'] + reply_text + separator)

    input_ids = [tokenizer_speech_token(text, tokenizer, return_tensors='pt') for text in prompts]
    labels = copy.deepcopy(input_ids)

    for label_ids, turn_pair in zip(labels, sources):
        # Hide the prompt (speech placeholder) from the loss.
        prompt_len = len(tokenizer_speech_token(turn_pair[0]['value'], tokenizer))
        label_ids[:prompt_len] = IGNORE_INDEX

    return {'input_ids': input_ids, 'labels': labels}
|
|
|
|
|
|
|
|
def preprocess(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    has_speech: bool = False
) -> Dict:
    """Build ``input_ids``/``labels`` for a batch of conversations.

    Dispatches on the active ``conversation_lib.default_conversation``:

    * ``sep_style == SeparatorStyle.PLAIN`` -> ``preprocess_plain``.
      ``has_speech`` is not forwarded; the plain path always encodes with the
      speech tokenizer.
    * ``version`` starting with ``"v1"`` -> ``preprocess_v1``, forwarding
      ``has_speech``.

    Args:
        sources: Conversations, each a list of ``{"from", "value"}`` turn dicts.
        tokenizer: Tokenizer used to encode the rendered prompts.
        has_speech: Whether prompts contain ``<speech>`` placeholders (v1 path only).

    Returns:
        A dict with ``input_ids`` and ``labels`` as produced by the chosen handler.

    Raises:
        NotImplementedError: If the default conversation matches neither case.
    """
    default_conv = conversation_lib.default_conversation
    if default_conv.sep_style == conversation_lib.SeparatorStyle.PLAIN:
        return preprocess_plain(sources, tokenizer)
    if default_conv.version.startswith("v1"):
        return preprocess_v1(sources, tokenizer, has_speech=has_speech)
    raise NotImplementedError(
        f"preprocess() has no handler for conversation version "
        f"{default_conv.version!r} (sep_style={default_conv.sep_style})"
    )