from enum import auto, Enum
import copy
import dataclasses
from typing import List, Optional

from util.easydict import EasyDict


class SeparatorStyle(Enum):
    """Different separator styles."""
    SINGLE = auto()
    TWO = auto()
    MPT = auto()


class MultiModalConvStyle(Enum):
    """Different multi-modal conversation styles."""
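    # MM_ALONE: the multi-modal token is emitted as its own message;
    # MM_INTERLEAF: the token is prepended to the user query when not already present.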
    MM_ALONE = 'mm_alone'
    MM_INTERLEAF = 'mm_interleaf'


@dataclasses.dataclass
class Conversation(EasyDict):
    """A class to manage all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    sep: List[str]
    mm_token: str

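    # Optional knobs: where the multi-modal token is placed, plus prompt
    # fragments wrapped around user queries and assistant answers.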
    mm_style: MultiModalConvStyle = MultiModalConvStyle.MM_INTERLEAF
    pre_query_prompt: Optional[str] = None
    post_query_prompt: Optional[str] = None
    pre_answer_prompt: Optional[str] = None
    post_answer_prompt: Optional[str] = None

    def __init__(self, *args, **kwargs):
        """Initialize the conversation. Ensures separators are properly set."""
        super().__init__(*args, **kwargs)
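        # A single separator string is broadcast so that every role gets its own entry.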
        if isinstance(self.sep, str):
            self.sep = [self.sep for _ in self.roles]

    def get_prompt(self):
        """Construct the prompt for the conversation from the current history."""
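        # Map each role to its separator, e.g. {"USER: ": " ", "ASSISTANT:": ""}.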
        sep = dict(zip(self.roles, self.sep))
        ret = (self.system + sep[self.roles[0]]) if self.system != "" else ""
        for i, (role, message) in enumerate(self.messages):
            if i + 1 == len(self.messages):
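                # No separator after the final message, so generation
                # continues directly from it.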
                ret += role + message
            else:
                ret += role + message + sep[role]
        return ret

    def user_query(self, query=None, pre_query_prompt=None, post_query_prompt=None, is_mm=False, num_mm_token=1):
        """Append the user's query to the conversation."""
        if post_query_prompt is not None:
            query = f"{query}{post_query_prompt}"
        if pre_query_prompt is not None:
            query = f"{pre_query_prompt}{query}"
        role = self.roles[0]
        if is_mm:
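            # Repeat the token body and keep its last character (e.g. the
            # trailing '\n') only once: '<video>\n' with num_mm_token=2
            # becomes '<video><video>\n'.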
            mm_str = num_mm_token * self.mm_token[:-1] + self.mm_token[-1]
            if self.mm_style == MultiModalConvStyle.MM_ALONE:
                self._append_message(role, mm_str)
            elif self.mm_style == MultiModalConvStyle.MM_INTERLEAF:
                if self.mm_token not in query:
                    query = f'{mm_str}{query}'
        self._append_message(role, query)

    def assistant_response(self, response, pre_answer_prompt=None, post_answer_prompt=None):
        """Append the assistant's response to the conversation."""
        if pre_answer_prompt is not None:
            response = f"{pre_answer_prompt}{response}"
        if post_answer_prompt is not None:
            response = f"{response}{post_answer_prompt}"
        role = self.roles[1]
        self._append_message(role, response)

    def _append_message(self, role, message):
        """Helper function to append a message to the conversation."""
        message = '' if message is None else message
        self.messages.append([role, message])

    def copy(self):
        """Return a deep copy of the conversation."""
        return copy.deepcopy(self)


llava_next_video_template = Conversation(
    system="",
    roles=("USER: ", "ASSISTANT:"),
    messages=[],
    sep=[" ", ""],
    mm_token='<video>\n',
    mm_style=MultiModalConvStyle.MM_INTERLEAF,
)

llava_next_video_template_with_space_after_assistant = Conversation(
    system="",
    roles=("USER: ", "ASSISTANT: "),
    messages=[],
    sep=[" ", ""],
    mm_token='<video>\n',
    mm_style=MultiModalConvStyle.MM_INTERLEAF,
)

llava_next_video_template_with_space_after_user = Conversation(
    system="",
    roles=("USER: ", "ASSISTANT:"),
    messages=[],
    sep=[" ", ""],
    mm_token='<video>\n',
    mm_style=MultiModalConvStyle.MM_INTERLEAF,
)

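# Registry of reusable template prototypes; call .copy() on one before mutating it.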
ConversationTemplates = {
    "llava_next_video_template": llava_next_video_template,
    "llava_next_video_template_with_space_after_assistant": llava_next_video_template_with_space_after_assistant,
    "llava_next_video_template_with_space_after_user": llava_next_video_template_with_space_after_user,
}


if __name__ == '__main__':
    example = Conversation(
        system="You are Video-ChatGPT, a large vision-language assistant. "
               "You are able to understand the video content that the user provides, "
               "and assist the user with a variety of tasks using natural language. "
               "Follow the instructions carefully and explain your answers in detail "
               "based on the provided video.",
        roles=("USER:", "ASSISTANT:"),
        messages=[],
        sep=["</user>", "</assi>"],
        mm_token='<image>',
        mm_style=MultiModalConvStyle.MM_INTERLEAF,
    )

print("Test 1 - Empty Conversation:") |
|
|
print(example.get_prompt()) |
|
|
|
|
|
|
|
|
print("\nTest 2 - User Query Added:") |
|
|
example.user_query("What is the content of the video?", pre_query_prompt="Please describe") |
|
|
print(example.get_prompt()) |
|
|
|
|
|
|
|
|
print("\nTest 3 - Assistant Response Added:") |
|
|
example.assistant_response("The video shows a person walking a dog in a park.") |
|
|
print(example.get_prompt()) |
|
|
|
|
|
|
|
|
print("\nTest 4 - Multi-modal Query:") |
|
|
example.user_query("Can you analyze the object?", is_mm=True) |
|
|
print(example.get_prompt()) |