# video_llm_template/util/conversation.py
# Uploaded by RoadQAQ via huggingface_hub (revision 710b71f, verified).
import copy
import dataclasses
from enum import auto, Enum
from typing import Any, List, Optional

from util.easydict import EasyDict
class SeparatorStyle(Enum):
    """Enumerates the separator layouts a conversation template may use."""

    SINGLE = auto()  # one shared separator for every turn
    TWO = auto()     # alternating pair of separators
    MPT = auto()     # MPT-style multi-part separators
class MultiModalConvStyle(Enum):
    """How multi-modal (media) tokens are placed relative to the text turns."""

    MM_ALONE = 'mm_alone'  # media token emitted as its own standalone message
    # NOTE(review): the value spells 'inferleaf' (likely a typo for
    # 'interleaf'); kept byte-identical because external code may compare
    # against the raw string — confirm before renaming.
    MM_INTERLEAF = 'mm_inferleaf'  # media token prepended to the query text
@dataclasses.dataclass
class Conversation(EasyDict):
    """A class to manage all conversation history.

    Stores the system prompt, the role labels, the running message history
    and the per-role separators, and renders them into a single prompt
    string via :meth:`get_prompt`.  Multi-modal media placeholders
    (``mm_token``) are either interleaved with the user text or emitted as
    standalone messages, depending on ``mm_style``.
    """

    system: str                # system prompt prepended to the rendered prompt
    roles: List[str]           # role labels, e.g. ("USER:", "ASSISTANT:")
    messages: List[List[str]]  # history as [role, message] pairs
    sep: List[str]             # per-role separators (a single str is broadcast)
    mm_token: str              # placeholder token for media, e.g. '<image>'

    # Defaults for multi-modal placement and the optional wrapper prompts.
    mm_style: MultiModalConvStyle = MultiModalConvStyle.MM_INTERLEAF
    pre_query_prompt: Optional[str] = None   # default prefix for user queries
    post_query_prompt: Optional[str] = None  # default suffix for user queries
    pre_answer_prompt: Optional[str] = None  # default prefix for responses
    post_answer_prompt: Optional[str] = None # default suffix for responses

    def __init__(self, *args, **kwargs):
        """Initialize the conversation and normalize separators.

        If ``sep`` is given as a single string, it is broadcast so that
        every role uses the same separator.
        """
        super().__init__(*args, **kwargs)
        if isinstance(self.sep, str):
            self.sep = [self.sep for _ in self.roles]

    def get_prompt(self):
        """Render the conversation history into a single prompt string.

        Every message except the last is followed by its role's separator;
        the last message is left open so the model can continue from it.
        """
        # Map each role to its separator.
        sep = dict(zip(self.roles, self.sep))
        # Start with the system message (plus the first role's separator)
        # when one is set.
        ret = self.system + sep[self.roles[0]] if self.system != "" else ""
        for i, (role, message) in enumerate(self.messages):
            if i + 1 == len(self.messages):
                # Last message: no trailing separator.
                ret += role + message
            else:
                ret += role + message + sep[role]
        return ret

    def user_query(self, query=None, pre_query_prompt=None, post_query_prompt=None, is_mm=False, num_mm_token=1):
        """Append the user's query to the conversation.

        Args:
            query: The user's text (``None`` is stored as an empty message).
            pre_query_prompt: Prefix for the query; falls back to
                ``self.pre_query_prompt`` when not given.
            post_query_prompt: Suffix for the query; falls back to
                ``self.post_query_prompt`` when not given.
            is_mm: Whether this turn carries multi-modal media.
            num_mm_token: How many media placeholders to emit.
        """
        # Bug fix: fall back to the instance-level default prompts, which
        # were declared on the class but previously never consulted.
        if pre_query_prompt is None:
            pre_query_prompt = self.pre_query_prompt
        if post_query_prompt is None:
            post_query_prompt = self.post_query_prompt
        if post_query_prompt is not None:
            query = f"{query}{post_query_prompt}"
        if pre_query_prompt is not None:
            query = f"{pre_query_prompt}{query}"
        role = self.roles[0]  # user role
        if is_mm:
            # Repeat the token body num_mm_token times, emitting its final
            # character (typically '>' or '\n') only once at the end.
            mm_str = num_mm_token * self.mm_token[:-1] + self.mm_token[-1]
            if self.mm_style == MultiModalConvStyle.MM_ALONE:
                # Media placeholder becomes its own standalone message.
                self._append_message(role, mm_str)
            elif self.mm_style == MultiModalConvStyle.MM_INTERLEAF:
                if self.mm_token not in query:
                    query = f'{mm_str}{query}'  # prepend media token to query
        self._append_message(role, query)

    def assistant_response(self, response, pre_answer_prompt=None, post_answer_prompt=None):
        """Append the assistant's response to the conversation.

        The prefix/suffix arguments fall back to the instance-level
        ``pre_answer_prompt``/``post_answer_prompt`` defaults when omitted.
        """
        if pre_answer_prompt is None:
            pre_answer_prompt = self.pre_answer_prompt
        if post_answer_prompt is None:
            post_answer_prompt = self.post_answer_prompt
        if pre_answer_prompt is not None:
            response = f"{pre_answer_prompt}{response}"
        if post_answer_prompt is not None:
            response = f"{response}{post_answer_prompt}"
        role = self.roles[1]  # assistant role
        self._append_message(role, response)

    def _append_message(self, role, message):
        """Append a single [role, message] pair; ``None`` becomes ''."""
        message = '' if message is None else message
        self.messages.append([role, message])

    def copy(self):
        """Return a deep copy of the conversation."""
        return copy.deepcopy(self)
# Chat template for LLaVA-NeXT-Video style models: "USER: <q> ASSISTANT:<a>".
llava_next_video_template = Conversation(
    system="",
    roles=("USER: ","ASSISTANT:"),
    messages=[],
    sep=[" ",""],
    mm_token='<video>\n',
    mm_style=MultiModalConvStyle.MM_INTERLEAF
)
# Variant with a trailing space after "ASSISTANT:" so generation starts
# one character later.
llava_next_video_template_with_space_after_assistant = Conversation(
    system="",
    roles=("USER: ","ASSISTANT: "),
    messages=[],
    sep=[" ",""],
    mm_token='<video>\n',
    mm_style=MultiModalConvStyle.MM_INTERLEAF
)
# NOTE(review): this variant is byte-identical to llava_next_video_template
# (the base roles already contain "USER: " with a trailing space) — confirm
# whether a different role or separator was intended here.
llava_next_video_template_with_space_after_user = Conversation(
    system="",
    roles=("USER: ","ASSISTANT:"),
    messages=[],
    sep=[" ",""],
    mm_token='<video>\n',
    mm_style=MultiModalConvStyle.MM_INTERLEAF
)
# Registry mapping template names to their Conversation instances.
# NOTE(review): these are shared module-level instances — callers should
# .copy() before mutating (user_query/assistant_response append in place).
ConversationTemplates = {
    "llava_next_video_template": llava_next_video_template,
    "llava_next_video_template_with_space_after_assistant": llava_next_video_template_with_space_after_assistant,
    "llava_next_video_template_with_space_after_user" : llava_next_video_template_with_space_after_user
}
# Manual smoke tests exercising the Conversation API end to end.
if __name__ == '__main__':
    # Example Video-ChatGPT conversation with interleaved image tokens.
    convo = Conversation(
        system=(
            "You are Video-ChatGPT, a large vision-language assistant. "
            "You are able to understand the video content that the user provides, and assist the user with a variety of tasks using natural language. "
            "Follow the instructions carefully and explain your answers in detail based on the provided video."
        ),
        roles=("USER:", "ASSISTANT:"),   # user / assistant role labels
        messages=[],                     # start with an empty history
        sep=["</user>", "</assi>"],      # per-role message separators
        mm_token='<image>',              # multimedia placeholder token
        mm_style=MultiModalConvStyle.MM_INTERLEAF,
    )

    # Test 1: prompt rendering with an empty history.
    print("Test 1 - Empty Conversation:")
    print(convo.get_prompt())

    # Test 2: a user query with a prefix prompt.
    print("\nTest 2 - User Query Added:")
    convo.user_query("What is the content of the video?", pre_query_prompt="Please describe")
    print(convo.get_prompt())

    # Test 3: an assistant response following the query.
    print("\nTest 3 - Assistant Response Added:")
    convo.assistant_response("The video shows a person walking a dog in a park.")
    print(convo.get_prompt())

    # Test 4: a multi-modal query (image token interleaved with the text).
    print("\nTest 4 - Multi-modal Query:")
    convo.user_query("Can you analyze the object?", is_mm=True)
    print(convo.get_prompt())