import os
import gc
import copy
import time

import torch
import warnings
import transformers

import numpy as np

from typing import Dict, Optional, Sequence
from omnilmm import conversation as conversation_lib

# Label value ignored by PyTorch's CrossEntropyLoss (its default ignore_index),
# used to exclude non-response tokens from the loss.
IGNORE_INDEX = -100
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"

def _tokenize_fn(strings: Sequence[str],
                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ) for text in strings
    ]
    # NOTE: `input_ids` and `labels` are bound to the *same* list of tensors;
    # callers that modify labels in place must copy them first.
    input_ids = labels = [
        tokenized.input_ids[0] for tokenized in tokenized_list
    ]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
        for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )
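
# A minimal usage sketch (illustrative only; `my_tokenizer` stands in for any
# transformers.PreTrainedTokenizer):
#
#   out = _tokenize_fn(["Hello world"], my_tokenizer)
#   out["input_ids"][0]       # 1-D LongTensor of token ids
#   out["input_ids_lens"][0]  # number of non-padding tokens
#
# Each string is tokenized as a batch of one, so padding="longest" adds no
# padding; truncation to tokenizer.model_max_length is the effective limit.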


def omni_preprocess(sources,
                    tokenizer: transformers.PreTrainedTokenizer,
                    generation=False):
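    """Build training inputs/labels for multi-turn chat data.

    Each conversation in `sources` is normalized to chat-template roles,
    rendered with the tokenizer's chat template, and tokenized; labels are
    then masked with `ignore_index` everywhere except assistant responses,
    so the loss is computed only on response tokens. With `generation=True`,
    a generation prompt is appended instead of stripping trailing whitespace.
    """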
    system_content = 'You are an artificial intelligence assistant, which gives helpful, detailed, and polite answers to the human\'s questions.'
    ignore_index = IGNORE_INDEX

    # Role markers rendered by the chat template; their token ids delimit the
    # spans to keep (assistant responses) and to mask (user instructions).
    response_template = '\n<|assistant|>\n'
    instruction_template = '\n<|user|>\n'
    response_token_ids = tokenizer.encode(
        response_template, add_special_tokens=False)
    instruction_token_ids = tokenizer.encode(
        instruction_template, add_special_tokens=False)

    batch_input_ids = []
    batch_labels = []
    for i in range(len(sources)):
        new_source = []
        prev_role = 'unexpected'
        for conv_turn in sources[i]:
            # Accept both {'from': ..., 'value': ...} and
            # {'role': ..., 'content': ...} turn formats.
            role = conv_turn['from'] if 'from' in conv_turn else conv_turn['role']
            content = conv_turn['value'] if 'value' in conv_turn else conv_turn['content']

            # Normalize legacy role names to chat-template roles.
            role = 'user' if role == 'human' else role
            role = 'assistant' if role == 'gpt' else role

            assert role in ['user', 'assistant']
            # Turns must alternate strictly between user and assistant.
            assert role != prev_role, f'role={role}, prev_role={prev_role}'
            prev_role = role

            new_source.append({
                'role': role,
                'content': content
            })
        if new_source[0]['role'] != 'system':
            new_source.insert(0, {'role': 'system', 'content': system_content})

        # The chat template may append trailing whitespace/newlines; strip
        # them for training so the sequence ends with the final response token.
        res_text = tokenizer.apply_chat_template(
            new_source, tokenize=False, add_generation_prompt=generation)
        if not generation:
            res_text = res_text.strip()

        conversations_tokenized = _tokenize_fn([res_text], tokenizer)
        res_input_ids = conversations_tokenized["input_ids"][0]

        # `labels` and `input_ids` from _tokenize_fn reference the same
        # tensors, so deep-copy before masking the labels in place.
        res_labels = copy.deepcopy(conversations_tokenized["labels"][0])

        response_token_ids_idxs = []
        human_token_ids_idxs = []

        # Locate each assistant response: match the full response-template
        # token sequence and record the index just past it (where the
        # response content starts).
        for assistant_idx in np.where(res_labels == response_token_ids[0])[0]:
            if response_token_ids == res_labels[
                    assistant_idx: assistant_idx + len(response_token_ids)].tolist():
                response_token_ids_idxs.append(
                    assistant_idx + len(response_token_ids))

        if len(response_token_ids_idxs) == 0:
            warnings.warn(
                f"Could not find response key `{response_template}` in the "
                f"following instance: @===>{tokenizer.decode(res_input_ids)}<===@ "
                f"Raw text is @===>{res_text}<===@ "
                f"Raw source is @===>{new_source}<===@ "
                f"This instance will be ignored in loss calculation. "
                f"Note, if this happens often, consider increasing the `max_seq_length`."
            )
            res_labels[:] = ignore_index

        # Locate each user instruction: record the index where the
        # instruction template itself starts.
        human_token_ids = instruction_token_ids
        for human_idx in np.where(res_labels == human_token_ids[0])[0]:
            if human_token_ids == res_labels[human_idx: human_idx + len(human_token_ids)].tolist():
                human_token_ids_idxs.append(human_idx)

        if len(human_token_ids_idxs) == 0:
            warnings.warn(
                f"Could not find instruction key `{instruction_template}` in the "
                f"following instance: @===>{tokenizer.decode(res_input_ids)}<===@ "
                f"Raw text is @===>{res_text}<===@ "
                f"Raw source is @===>{new_source}<===@ "
                f"This instance will be ignored in loss calculation. "
                f"Note, if this happens often, consider increasing the `max_seq_length`."
            )
            res_labels[:] = ignore_index

        # Mask everything that is not an assistant response: the first span
        # masks from the start of the sequence (system prompt + first user
        # turn) up to the first response; later spans mask from the start of
        # each instruction to the start of the following response.
        for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
            if idx != 0:
                res_labels[start:end] = ignore_index
            else:
                res_labels[:end] = ignore_index

        # A conversation that ends with an unanswered instruction leaves one
        # extra instruction index; mask those trailing tokens as well.
        if len(response_token_ids_idxs) < len(human_token_ids_idxs):
            res_labels[human_token_ids_idxs[-1]:] = ignore_index

        batch_input_ids.append(res_input_ids)
        batch_labels.append(res_labels)

    return dict(input_ids=batch_input_ids, labels=batch_labels)
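
# Illustrative usage sketch, not part of the original module. It assumes a
# tokenizer whose chat template renders roles as '\n<|user|>\n' and
# '\n<|assistant|>\n' (e.g. a Zephyr-style template); '<chat-model>' is a
# placeholder checkpoint name.
#
#   tokenizer = transformers.AutoTokenizer.from_pretrained('<chat-model>')
#   batch = omni_preprocess(
#       sources=[[{'from': 'human', 'value': 'Describe the image.'},
#                 {'from': 'gpt', 'value': 'A cat sitting on a mat.'}]],
#       tokenizer=tokenizer)
#   # batch['labels'][0] equals IGNORE_INDEX everywhere except the tokens of
#   # the assistant response.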