| |
|
|
| """ |
| Define the conversation format for each training phase and language model. |
| |
| Modified from LLaVA codebase: https://github.com/haotian-liu/LLaVA/blob/main/llava/conversation.py |
| |
| NOTE: |
| - an example of required json format is: |
| data = { |
| "image": IMAGE_PATH, or "images": LIST of IMAGE_PATH, |
| "conversations": [ |
| {"from": "human", "value": "hello"}, |
| {"from": "assistant", "value": "Hi, how can I help you today?"}, |
| {"from": "human", "value": "Who are you?"}, |
| {"from": "assistant", "value": "I am a multimodal large language model created by FAIR. I can assist you with questions related to images and videos."}, |
| ] |
| } |
| """ |
|
|
| import copy |
| from dataclasses import dataclass |
| from typing import Callable, Dict, List, Union |
|
|
|
|
@dataclass
class Conversation:
    """Prompt template that renders a chat transcript into model-ready text.

    A template is a set of prefixes/separators wrapped around the system
    prompt and around every user / assistant message, plus a rule
    (``place_image_token``) deciding where image tokens go relative to the
    user text. ``conversations`` holds the raw turns to render, each one a
    dict of the form ``{"from": "human" | "assistant", "value": ...}``.
    """

    # System prompt text; emitted once, in front of the first user message.
    system: str
    # Raw turns: dicts with a "from" key ("human"/"assistant") and a "value".
    conversations: list
    # Prepended to the very first user text of the rendered conversation.
    bos_token: str
    # Appended right after the system prompt / user text / assistant text.
    sep_system: str
    sep_question: str
    sep_answer: str
    # Called as place_image_token(text, image_token, num_image_tokens) and
    # must return the user text with the image tokens inserted.
    place_image_token: Callable
    # Token that is repeated once per visual patch.
    image_token: str = "<|image|>"
    # Emitted right before the system prompt / user text / assistant text.
    pre_system: str = ""
    pre_question: str = ""
    pre_answer: str = ""
    # Appended to the last assistant text of the rendered conversation.
    eos_token: str = ""

    @staticmethod
    def _strip_media_placeholders(text: str) -> str:
        """Remove ``<image>``/``<video>`` placeholders and one adjacent newline."""
        for placeholder in ("<image>", "<video>"):
            text = (
                text.replace(f"{placeholder}\n", "")
                .replace(f"\n{placeholder}", "")
                .replace(placeholder, "")
            )
        return text

    def get_conversation_dict_list(
        self, num_images: int = 1, num_patches: int = 144, media_type: str = "image"
    ) -> List[Dict]:
        """Render the stored turns into a list of user/assistant exchanges.

        Each returned dict has a ``"user"`` and an ``"assistant"`` key holding
        the fully formatted text of one exchange. The system prompt is
        prepended to the first user text, ``bos_token`` to the first exchange
        and ``eos_token`` to the last one.

        Args:
            num_images: Number of images (or frames) attached to the sample.
            num_patches: Number of image tokens emitted per image.
            media_type: ``"multi_image"`` expands every ``<image>`` placeholder
                of every human turn in place to ``num_patches`` image tokens.
                Any other value strips the placeholders from the first human
                turn and lets ``place_image_token`` insert
                ``num_images * num_patches`` image tokens into it.

        Returns:
            One dict per completed human -> assistant exchange. Assistant turns
            that precede the first human turn are skipped, and a human turn
            that is never answered is not emitted.

        Raises:
            ValueError: If a turn's ``"from"`` is neither ``"human"`` nor
                ``"assistant"``.
            IndexError: If the conversation contains no completed exchange.
        """
        # Render from shallow copies: the stored turns (which may be the very
        # dict objects the caller passed to add_conv) must stay untouched,
        # otherwise a second call would insert the image tokens again.
        turns = [copy.copy(turn) for turn in self.conversations]

        if media_type == "multi_image":
            per_image_tokens = self.image_token * num_patches
            for turn in turns:
                if turn["from"] == "human":
                    turn["value"] = turn["value"].replace("<image>", per_image_tokens)
        else:
            # Target the first *human* turn rather than index 0: a leading
            # assistant turn is skipped below, and image tokens written into
            # it would silently disappear from the prompt.
            first_human = next(
                (turn for turn in turns if turn["from"] == "human"), None
            )
            if first_human is not None:
                first_human["value"] = self.place_image_token(
                    self._strip_media_placeholders(first_human["value"]),
                    self.image_token,
                    num_images * num_patches,
                )

        sys_text = self.pre_system + self.system + self.sep_system
        conv_dict_list = []
        is_first = True
        for turn in turns:
            role = turn["from"]
            if role == "human":
                # Only the first user message carries the system prompt.
                prefix = sys_text if is_first else ""
                conv_dict = {
                    "user": prefix
                    + self.pre_question
                    + turn["value"]
                    + self.sep_question
                }
                is_first = False
            elif role == "assistant":
                if is_first:
                    # No question has been asked yet; nothing to answer.
                    continue
                conv_dict.update(
                    {"assistant": self.pre_answer + str(turn["value"]) + self.sep_answer}
                )
                conv_dict_list.append(conv_dict)
            else:
                raise ValueError(
                    f"conv['from'] must be human or assistant, but got {role}. "
                    "Please fix your jsonl file."
                )

        if not conv_dict_list:
            # Same exception type the bare list indexing used to raise, but
            # with a message that points at the actual problem.
            raise IndexError(
                "conversation has no completed human -> assistant exchange to render"
            )

        conv_dict_list[0]["user"] = f"{self.bos_token}{conv_dict_list[0]['user']}"
        conv_dict_list[-1][
            "assistant"
        ] = f"{conv_dict_list[-1]['assistant']}{self.eos_token}"

        return conv_dict_list

    def get_generation_prompt(
        self, prompt: str, num_images: int = 1, num_patches: int = 144
    ) -> str:
        """Build the inference prompt for a single user message.

        If ``prompt`` contains exactly ``num_images`` ``<image>`` placeholders,
        each one is expanded in place to ``num_patches`` image tokens.
        Otherwise all ``<image>``/``<video>`` placeholders are removed and
        ``place_image_token`` inserts ``num_images * num_patches`` tokens.

        Returns:
            BOS token, system block and user block, followed by ``pre_answer``
            so that the model continues with the assistant reply.
        """
        if prompt.count("<image>") == num_images:
            prompt = prompt.replace("<image>", self.image_token * num_patches)
        else:
            prompt = self.place_image_token(
                self._strip_media_placeholders(prompt),
                self.image_token,
                num_images * num_patches,
            )

        sys_text = self.bos_token + self.pre_system + self.system + self.sep_system
        return (
            sys_text + self.pre_question + prompt + self.sep_question + self.pre_answer
        )

    def add_conv(self, conv: Union[List, Dict]):
        """Append one turn (dict) or several turns (list) to the conversation.

        The turn objects are stored as given, not copied.

        Raises:
            ValueError: If ``conv`` is neither a list nor a dict.
        """
        if isinstance(conv, list):
            self.conversations.extend(conv)
        elif isinstance(conv, dict):
            self.conversations.append(conv)
        else:
            raise ValueError(f"conv must be a list or dict, but got {type(conv)}")

    def copy(self) -> "Conversation":
        """Return a new template with the same settings and deep-copied turns."""
        # `copy` here is the module imported at file level; method bodies do
        # not see class attributes as bare names, so there is no clash.
        return Conversation(
            system=self.system,
            conversations=copy.deepcopy(self.conversations),
            place_image_token=self.place_image_token,
            bos_token=self.bos_token,
            sep_system=self.sep_system,
            sep_question=self.sep_question,
            sep_answer=self.sep_answer,
            pre_system=self.pre_system,
            pre_question=self.pre_question,
            pre_answer=self.pre_answer,
            image_token=self.image_token,
            eos_token=self.eos_token,
        )
|
|
|
|
# Warm-up template: no system prompt, no role prefixes and no BOS/EOS text.
# The user text is discarded entirely (only the image tokens are kept) and
# every answer is terminated by a newline.
conv_warmup = Conversation(
    system="",
    conversations=[],
    bos_token="",
    sep_system="",
    sep_question="",
    sep_answer="\n",
    place_image_token=lambda text, image_token, num_image_tokens: (
        image_token * num_image_tokens
    ),
    image_token="<|image|>",
    pre_system="",
    pre_question="",
    pre_answer="",
    eos_token="",
)
|
|
# System prompt shared by every supervised fine-tuning sample.
_PLM_SFT_SYSTEM_PROMPT = (
    "You are a helpful language and vision assistant. "
    "You are able to understand the visual content that the user provides, "
    "and assist the user with a variety of tasks using natural language."
)

# SFT template: each segment is wrapped in a role header and closed with
# <|eot_id|>; image tokens are placed in front of the user text.
# NOTE(review): the special tokens look like the Llama 3 chat format -- confirm
# against the tokenizer that is actually used.
conv_plm_sft = Conversation(
    system=_PLM_SFT_SYSTEM_PROMPT,
    conversations=[],
    bos_token="<|begin_of_text|>",
    sep_system="<|eot_id|>",
    sep_question="<|eot_id|>",
    sep_answer="<|eot_id|>",
    place_image_token=lambda text, image_token, num_image_tokens: (
        (image_token * num_image_tokens) + text
    ),
    image_token="<|image|>",
    pre_system="<|start_header_id|>system<|end_header_id|>\n\n",
    pre_question="<|start_header_id|>user<|end_header_id|>\n\n",
    pre_answer="<|start_header_id|>assistant<|end_header_id|>\n\n",
    eos_token="<|end_of_text|>",
)
|
|
|
|
# Lookup table from template name to the shared template instance.
# The instances are module-level objects; take a `.copy()` before adding
# turns so that state does not leak between samples.
REGISTERED_CONVS = dict(
    warmup=conv_warmup,
    plm_sft=conv_plm_sft,
)
|
|