| | import os |
| | import gc |
| | import copy |
| | import time |
| |
|
| | import torch |
| | import warnings |
| | import transformers |
| |
|
| | import numpy as np |
| |
|
| | from typing import Dict, Optional, Sequence |
| | from omnilmm import conversation as conversation_lib |
| |
|
# Label value skipped by PyTorch's cross-entropy loss (ignore_index convention).
IGNORE_INDEX = -100
# Special placeholder/control tokens used to splice image features into the
# text stream; presumably consumed/expanded elsewhere by the multimodal model's
# tokenizer setup — not referenced in this chunk, confirm against callers.
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
| |
|
| |
|
def _tokenize_fn(strings: Sequence[str],
                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize each string independently and report its non-pad length.

    Args:
        strings: Texts to tokenize, one sequence per entry.
        tokenizer: HuggingFace tokenizer; sequences longer than
            ``tokenizer.model_max_length`` are truncated.

    Returns:
        dict with:
            input_ids / labels: list of 1-D tensors (the same list object
                for both keys — callers that modify labels must copy first);
            input_ids_lens / labels_lens: list of ints counting tokens that
                are not ``pad_token_id``.
    """
    token_tensors = []
    token_counts = []
    for text in strings:
        # Each text is encoded on its own, so padding="longest" is a no-op
        # per call; truncation caps the length at model_max_length.
        encoded = tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        token_tensors.append(encoded.input_ids[0])
        # Length = number of non-pad tokens in the encoded sequence.
        token_counts.append(
            encoded.input_ids.ne(tokenizer.pad_token_id).sum().item())
    # input_ids and labels deliberately alias the same list (as do the two
    # length lists), matching the original contract.
    return dict(
        input_ids=token_tensors,
        labels=token_tensors,
        input_ids_lens=token_counts,
        labels_lens=token_counts,
    )
| |
|
| |
|
| |
|
def omni_preprocess(sources,
                    tokenizer: transformers.PreTrainedTokenizer,
                    generation=False):
    """Render multi-turn conversations with the tokenizer's chat template and
    build (input_ids, labels) pairs where only assistant responses are
    supervised.

    Args:
        sources: Iterable of conversations. Each conversation is a list of
            turns keyed either as {'from': ..., 'value': ...} or
            {'role': ..., 'content': ...}; roles 'human'/'gpt' are normalized
            to 'user'/'assistant'. Turns must strictly alternate roles.
        tokenizer: Tokenizer whose chat template renders user turns with
            '\\n<|user|>\\n' and assistant turns with '\\n<|assistant|>\\n' —
            the masking below searches for exactly these rendered templates,
            so a different template silently breaks supervision.
        generation: If True, append the generation prompt and keep trailing
            whitespace (inference mode); if False, strip the rendered text
            for training.

    Returns:
        dict(input_ids=..., labels=...): parallel lists, one entry per
        conversation; labels equal input_ids except that every position
        outside an assistant response is set to -100 (ignored by the loss).
    """
    system_content = 'You are an artificial intelligence assistant, which gives helpful, detailed, and polite answers to the human\'s questions.'
    # Local masking value; same -100 convention as the module-level constant.
    ignore_index = -100

    # Marker strings the chat template is assumed to emit before each
    # assistant / user turn. Encoded without special tokens so they can be
    # matched as raw subsequences inside the full tokenized conversation.
    response_template = '\n<|assistant|>\n'
    instruction_template = '\n<|user|>\n'
    response_token_ids = tokenizer.encode(
        response_template, add_special_tokens=False)
    instruction_token_ids = tokenizer.encode(
        instruction_template, add_special_tokens=False)

    batch_input_ids = []
    batch_labels = []
    for i in range(len(sources)):
        # --- Normalize one conversation to [{'role','content'}, ...] ---
        new_source = []
        prev_role = 'unexpect'  # sentinel that cannot equal a real role
        for conv_turn in sources[i]:
            # Accept either the ShareGPT-style ('from'/'value') or the
            # OpenAI-style ('role'/'content') turn schema.
            role = conv_turn['from'] if 'from' in conv_turn else conv_turn['role']
            content = conv_turn['value'] if 'value' in conv_turn else conv_turn['content']

            role = 'user' if role == 'human' else role
            role = 'assistant' if role == 'gpt' else role

            # Enforce valid, strictly alternating roles.
            assert role in ['user', 'assistant']
            assert role != prev_role, f'role={role}, prev_role={prev_role}'
            prev_role = role

            new_turn = {
                'role': role,
                'content': content
            }
            new_source.append(new_turn)
        # Prepend the default system prompt when the data has none.
        if new_source[0]['role'] != 'system':
            new_source.insert(0, {'role': 'system', 'content': system_content})

        # Render the whole conversation to a single string via the chat
        # template; at inference time also append the generation prompt.
        res_text = tokenizer.apply_chat_template(
            new_source, tokenize=False, add_generation_prompt=generation)
        if not generation:
            res_text = res_text.strip()

        conversations_tokenized = _tokenize_fn([res_text], tokenizer)
        res_input_ids = conversations_tokenized["input_ids"][0]

        # _tokenize_fn returns input_ids and labels as the same objects, so
        # labels must be deep-copied before being masked in place.
        res_labels = copy.deepcopy(conversations_tokenized["labels"][0])

        # Indices where supervised (assistant) spans start, and where
        # instruction (user) spans start, within the token sequence.
        response_token_ids_idxs = []
        human_token_ids_idxs = []

        # Find every occurrence of the full response template: first match
        # on its leading token, then verify the whole template subsequence.
        # NOTE(review): np.where is applied to what is presumably a torch
        # tensor (from return_tensors="pt") — relies on tensor/ndarray
        # interop; confirm res_labels dtype at the call site.
        for assistant_idx in np.where(res_labels == response_token_ids[0])[0]:
            if (response_token_ids == res_labels[assistant_idx: assistant_idx + len(
                    response_token_ids)].tolist()
                ):
                # Store the index just AFTER the template, i.e. where the
                # assistant's actual response text begins.
                response_token_ids_idxs.append(
                    assistant_idx + len(response_token_ids))

        if len(response_token_ids_idxs) == 0:
            # No assistant span found (e.g. truncated away): drop the whole
            # instance from the loss by masking every label.
            warnings.warn(
                f"Could not find response key `{response_template}` in the "
                f'following instance: @===>{tokenizer.decode(res_input_ids)}<===@ '
                f'Raw text is @===>{res_text}<===@'
                f'Raw source is @===>{new_source}<===@'
                f"This instance will be ignored in loss calculation. "
                f"Note, if this happens often, consider increasing the `max_seq_length`."
            )
            res_labels[:] = ignore_index

        # Same subsequence search for the instruction template; here the
        # stored index is the START of the template (the whole user turn,
        # template included, gets masked).
        human_token_ids = instruction_token_ids
        for human_idx in np.where(res_labels == human_token_ids[0])[0]:
            if human_token_ids == res_labels[human_idx: human_idx + len(human_token_ids)].tolist():
                human_token_ids_idxs.append(human_idx)

        if len(human_token_ids_idxs) == 0:
            warnings.warn(
                f"Could not find instruction key `{instruction_template}` in the "
                f'following instance: @===>{tokenizer.decode(res_input_ids)}<===@ '
                f'Raw text is @===>{res_text}<===@'
                f'Raw source is @===>{new_source}<===@'
                f"This instance will be ignored in loss calculation. "
                f"Note, if this happens often, consider increasing the `max_seq_length`."
            )
            res_labels[:] = ignore_index

        # Mask every non-response region. Pairing the k-th user-turn start
        # with the k-th response start masks [user start, response start);
        # for the first pair also mask everything before it (system prompt,
        # BOS, first user turn).
        for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
            if idx != 0:
                res_labels[start:end] = ignore_index
            else:
                res_labels[:end] = ignore_index

        # A trailing user turn with no following response (more user starts
        # than responses) is masked to the end of the sequence.
        if len(response_token_ids_idxs) < len(human_token_ids_idxs):
            res_labels[human_token_ids_idxs[-1]:] = ignore_index

        batch_input_ids.append(res_input_ids)
        batch_labels.append(res_labels)

    return dict(input_ids=batch_input_ids, labels=batch_labels)
| |
|
| |
|
| |
|