import copy
from typing import Dict, Sequence

import torch
import transformers

from llava.constants import (
    IGNORE_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_TOKEN_INDEX,
)
from llava import conversation as conversation_lib
from llava.mm_utils import tokenizer_image_token


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )
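
# Example (a minimal sketch, assuming a tokenizer whose pad_token_id is set):
#   out = _tokenize_fn(["Hello world", "Hi"], tokenizer)
#   out["input_ids"]      -> one 1-D tensor of token ids per input string
#   out["input_ids_lens"] -> the number of non-pad tokens in each string
# Note that input_ids and labels reference the same list here; callers that need
# independent labels deepcopy or clone them downstream before masking.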


def _mask_targets(target, tokenized_lens, speakers):
    # cur_idx = 0
    cur_idx = tokenized_lens[0]
    tokenized_lens = tokenized_lens[1:]
    target[:cur_idx] = IGNORE_INDEX
    for tokenized_len, speaker in zip(tokenized_lens, speakers):
        if speaker == "human":
            target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX
        cur_idx += tokenized_len
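
# Masking sketch (illustrative numbers, not from the original file): with
# tokenized_lens = [5, 4, 6] (header, human turn, gpt turn) and
# speakers = ["human", "gpt"], the header span target[:5] and the human span
# target[7:9] are set to IGNORE_INDEX, so only the gpt turn contributes to the loss.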


def _add_speaker_and_signal(header, source, get_conversation=True):
    """Add speaker and start/end signal on each round."""
    BEGIN_SIGNAL = "### "
    END_SIGNAL = "\n"
    conversation = header
    for sentence in source:
        from_str = sentence["from"]
        if from_str.lower() == "human":
            from_str = conversation_lib.default_conversation.roles[0]
        elif from_str.lower() == "gpt":
            from_str = conversation_lib.default_conversation.roles[1]
        else:
            from_str = "unknown"
        sentence["value"] = BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL
        if get_conversation:
            conversation += sentence["value"]
    conversation += BEGIN_SIGNAL
    return conversation
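
# Example output (a sketch, assuming the default conversation roles are
# "Human" and "Assistant"):
#   "<system header>\n\n### Human: <question>\n### Assistant: <answer>\n### "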


def preprocess_multimodal(
    sources: Sequence[str], is_multimodal: bool, mm_use_im_start_end: bool
) -> Dict:
    if not is_multimodal:
        return sources

    for source in sources:
        for sentence in source:
            if DEFAULT_IMAGE_TOKEN in sentence["value"]:
                sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
                sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
                sentence["value"] = sentence["value"].strip()
                if "mmtag" in conversation_lib.default_conversation.version:
                    sentence["value"] = sentence["value"].replace(
                        DEFAULT_IMAGE_TOKEN,
                        "<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>",
                    )
            replace_token = DEFAULT_IMAGE_TOKEN
            if mm_use_im_start_end:
                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
            sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)

    return sources
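
# Example transformation (a sketch; <image>, <im_start>, <im_end> stand for the
# DEFAULT_IMAGE_TOKEN / DEFAULT_IM_START_TOKEN / DEFAULT_IM_END_TOKEN constants):
#   "What is shown here? <image>"  ->  "<image>\nWhat is shown here?"
#   and, with mm_use_im_start_end=True:
#   "<im_start><image><im_end>\nWhat is shown here?"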


def preprocess_llama_2(
    sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False
) -> Dict:
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations
    if has_image:
        input_ids = torch.stack(
            [
                tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
                for prompt in conversations
            ],
            dim=0,
        )
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2

    # Mask targets
    sep = "[/INST] "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}. (ignored)")

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
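
# Masking scheme (derived from the code above): each round in the LLaMA-2 chat
# template roughly looks like "[INST] <user text> [/INST] <assistant text></s>",
# and everything up to and including "[/INST] " is set to IGNORE_INDEX, so the
# loss is only computed on the assistant responses.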


def preprocess_llama_2_obj_identifier(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    obj_dict: Dict[str, Dict],
    obj_context_feature_type: str,
    mode: str,
) -> Dict:
    """Tokenize the conversation into the following format:

    %%%% Object-centric context: <obj_0>: <obj_0_feat>; <obj_1>: <obj_1_feat>; ... <obj_i>: <obj_i_feat>;%%%%

    where <obj_i_feat> is currently a placeholder (IMAGE_TOKEN_INDEX) that is later
    replaced by the actual feature in vector form. All string tokens are marked as
    not trainable; only the feature vectors are kept trainable.

    Args:
        sources: the conversation sources
        tokenizer (transformers.PreTrainedTokenizer): the tokenizer
        obj_dict (Dict[str, Dict]): the object dictionary for the scene
        obj_context_feature_type (str): the type of object feature to use for the object context
        mode (str): the processing mode; when set to "generate", the target-masking
            consistency check is skipped because there is no target

    Returns:
        Dict: the tokenized input_ids and labels
    """
    conv = conversation_lib.conv_llava_llama_2_obj_identifier.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations
    input_ids = torch.stack(
        [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations],
        dim=0,
    )

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2

    # Mask targets
    sep = "[/INST] "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            round_len = len(tokenizer_image_token(rou, tokenizer))
            instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        # Check that the target is correctly masked. When generating, there is no target.
        if cur_len < tokenizer.model_max_length and mode != "generate":
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}. (ignored)")

    assert (
        input_ids.shape[0] == targets.shape[0] == 1
    ), "Only support tokenization for one conversation at a time"
    input_id = input_ids[0]
    targets = targets[0]

    # TODO: replace -200 (IMAGE_TOKEN_INDEX) with object identifier tokens.
    # We want the LLM to see:
    #   %%%% Object-centric context: <obj_0>: <obj_0_feat>; <obj_1>: <obj_1_feat>; ... <obj_i>: <obj_i_feat>;%%%%
    # where <obj_i_feat> will later be replaced by the actual feature in vector form.
    if obj_context_feature_type == "vector":
        obj_context = "%%%% Object-centric context:"
        for obj_id, obj_info in obj_dict.items():
            # Use the sep token as a placeholder; it is later replaced by the actual feature vector.
            obj_context += f" <{obj_id}>: {tokenizer.sep_token};"
        obj_context += "%%%%"
        # Strip the bos and eos tokens
        tokenized_obj_context = tokenizer(obj_context).input_ids[1:-1]
        tokenized_obj_context = torch.tensor(tokenized_obj_context, dtype=torch.long)
        # Replace the sep token with IMAGE_TOKEN_INDEX, so that -200 later marks where
        # each feature vector should be inserted.
        tokenized_obj_context[tokenized_obj_context == tokenizer.sep_token_id] = IMAGE_TOKEN_INDEX
        # Mark all string tokens as not trainable; only keep the feature vectors trainable.
        tokenized_obj_context_target = tokenized_obj_context.clone()
        tokenized_obj_context_target[tokenized_obj_context != IMAGE_TOKEN_INDEX] = IGNORE_INDEX
    elif obj_context_feature_type == "text":
        obj_context = "%%%% Object-centric context:"
        for obj_id, obj_info in obj_dict.items():
            obj_context += f" <{obj_id}>: {obj_info};"
        obj_context += "%%%%"
        # Strip the bos and eos tokens
        tokenized_obj_context = tokenizer(obj_context).input_ids[1:-1]
        tokenized_obj_context = torch.tensor(tokenized_obj_context, dtype=torch.long)
        # Mark all tokens as not trainable
        tokenized_obj_context_target = tokenized_obj_context.clone()
        tokenized_obj_context_target[:] = IGNORE_INDEX

    # Insert the object context into input_id and target at the IMAGE_TOKEN_INDEX position.
    separation_idx = torch.where(input_id == IMAGE_TOKEN_INDEX)[0]
    input_id_with_obj_context = torch.cat(
        [input_id[:separation_idx], tokenized_obj_context, input_id[separation_idx + 1 :]]
    )
    target_with_obj_context = torch.cat(
        [
            targets[:separation_idx],
            tokenized_obj_context_target,
            targets[separation_idx + 1 :],
        ]
    )

    if obj_context_feature_type == "vector":
        return dict(
            input_ids=input_id_with_obj_context,
            labels=target_with_obj_context,
            # Return the object dictionary so that the features can be embedded later.
            obj_dict=obj_dict,
        )
    elif obj_context_feature_type == "text":
        return dict(input_ids=input_id_with_obj_context, labels=target_with_obj_context)


def preprocess_v1(
    sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False
) -> Dict:
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations
    if has_image:
        input_ids = torch.stack(
            [
                tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
                for prompt in conversations
            ],
            dim=0,
        )
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO

    # Mask targets
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}. (ignored)")

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
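
# Same masking idea as preprocess_llama_2, but for SeparatorStyle.TWO prompts
# (e.g. the Vicuna-style v1 template): the instruction part of each round, up to
# conv.sep + conv.roles[1] + ": " (typically rendered as " ASSISTANT: "), is
# ignored in the loss.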


def preprocess_mpt(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations
    input_ids = torch.stack(
        [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations],
        dim=0,
    )

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.MPT

    # Mask targets
    sep = conv.sep + conv.roles[1]
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep)
        re_rounds = [conv.sep.join(rounds[:3])]  # system + user + gpt
        for conv_idx in range(3, len(rounds), 2):
            re_rounds.append(conv.sep.join(rounds[conv_idx : conv_idx + 2]))  # user + gpt
        cur_len = 0
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(re_rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            round_len = len(tokenizer_image_token(rou, tokenizer)) + len(
                tokenizer_image_token(conv.sep, tokenizer)
            )
            instruction_len = len(tokenizer_image_token(parts[0], tokenizer))

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}. (ignored)")

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
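
# Round re-grouping sketch for the MPT chat format: splitting on conv.sep yields
# [system, user, gpt, user, gpt, ...], so re_rounds[0] bundles system + user + gpt
# and every later entry bundles one user + gpt pair; the instruction part of each
# bundle is then masked out of the loss.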


def preprocess_plain(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    # Add end signal and concatenate together
    conversations = []
    for source in sources:
        assert len(source) == 2
        assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
        source[0]["value"] = DEFAULT_IMAGE_TOKEN
        conversation = (
            source[0]["value"] + source[1]["value"] + conversation_lib.default_conversation.sep
        )
        conversations.append(conversation)

    # Tokenize conversations
    input_ids = [
        tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations
    ]
    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
        target[:tokenized_len] = IGNORE_INDEX

    return dict(input_ids=input_ids, labels=targets)
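
# Plain (pretraining-style) sketch: each sample is reduced to
# "<image token><caption><sep>", and only the caption plus separator tokens are
# supervised; the image-token prefix is masked with IGNORE_INDEX.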


def preprocess(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
) -> Dict:
    """
    Given a list of sources, each of which is a conversation list, this transform:
    1. Adds the signal '### ' at the beginning of each sentence, with the end signal '\n';
    2. Concatenates the conversations together;
    3. Tokenizes the concatenated conversation;
    4. Makes a deepcopy as the target and masks human words with IGNORE_INDEX.
    """
    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
        return preprocess_plain(sources, tokenizer)
    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
        return preprocess_llama_2(sources, tokenizer, has_image=has_image)
    if conversation_lib.default_conversation.version.startswith("v1"):
        return preprocess_v1(sources, tokenizer, has_image=has_image)
    if conversation_lib.default_conversation.version == "mpt":
        return preprocess_mpt(sources, tokenizer)

    # Add end signal and concatenate together
    conversations = []
    for source in sources:
        header = f"{conversation_lib.default_conversation.system}\n\n"
        conversation = _add_speaker_and_signal(header, source)
        conversations.append(conversation)

    # Tokenize conversations
    def get_tokenize_len(prompts):
        return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]

    if has_image:
        input_ids = [
            tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
            for prompt in conversations
        ]
    else:
        conversations_tokenized = _tokenize_fn(conversations, tokenizer)
        input_ids = conversations_tokenized["input_ids"]

    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        if has_image:
            tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
        else:
            tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)[
                "input_ids_lens"
            ]
        speakers = [sentence["from"] for sentence in source]
        _mask_targets(target, tokenized_lens, speakers)

    return dict(input_ids=input_ids, labels=targets)
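
# Usage sketch (hypothetical; assumes conversation_lib.default_conversation has been
# configured by the training script, `tokenizer` is a loaded PreTrainedTokenizer,
# and DEFAULT_IMAGE_TOKEN is "<image>"):
#   sources = [[
#       {"from": "human", "value": "<image>\nWhat is in the picture?"},
#       {"from": "gpt", "value": "A dog playing in the park."},
#   ]]
#   sources = preprocess_multimodal(sources, is_multimodal=True, mm_use_im_start_end=False)
#   batch = preprocess(sources, tokenizer, has_image=True)
#   # batch["input_ids"] holds the tokenized prompt (with IMAGE_TOKEN_INDEX at the
#   # image position) and batch["labels"] masks everything except the gpt response.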