from enum import auto, Enum import copy import dataclasses from util.easydict import EasyDict from typing import Any, List class SeparatorStyle(Enum): """Different separator styles.""" SINGLE = auto() # Single separator style TWO = auto() # Double separator style MPT = auto() # Multi-part separator style class MultiModalConvStyle(Enum): """Different multi-modal conversation styles.""" MM_ALONE = 'mm_alone' # Multi-modal token alone (appears separately) MM_INTERLEAF = 'mm_inferleaf' # Multi-modal token interleaved with text @dataclasses.dataclass class Conversation(EasyDict): """A class to manage all conversation history.""" system: str # System prompt roles: List[str] # List of roles in the conversation (e.g., "User", "Assistant") messages: List[List[str]] # History of conversation messages sep: List[str] # Separators for different roles mm_token: str # Multi-modal token to indicate media elements like images # Default values for multi-modal settings and additional prompts mm_style: MultiModalConvStyle = MultiModalConvStyle.MM_INTERLEAF pre_query_prompt: str = None # Optional prefix prompt before the user's query post_query_prompt: str = None # Optional suffix prompt after the user's query pre_answer_prompt: str = None # Optional prompt before the assistant's response post_answer_prompt: str = None def __init__(self, *args, **kwargs): """Initialize the conversation. Ensures separators are properly set.""" super().__init__(*args, **kwargs) # If separator is a string, apply the same separator for all roles if isinstance(self.sep, str): self.sep = [self.sep for _ in self.roles] def get_prompt(self): """Constructs the prompt for the conversation based on the current history.""" # If the separator is a string, create a dictionary that maps roles to separators sep = dict(zip(self.roles, self.sep)) # Start the prompt with the system message if available ret = self.system + sep[self.roles[0]] if self.system != "" else "" # Loop through the messages and format the prompt for i, (role, message) in enumerate(self.messages): if i + 1 == len(self.messages): # For the last message ret += role + message # Just add role and message else: ret += role + message + sep[role] # Add role, message, and separator for other messages return ret def user_query(self, query=None, pre_query_prompt=None, post_query_prompt=None, is_mm=False, num_mm_token=1): """Append the user's query to the conversation.""" # Add post-query and pre-query prompts if they exist if post_query_prompt is not None: query = f"{query}{post_query_prompt}" if pre_query_prompt is not None: query = f"{pre_query_prompt}{query}" role = self.roles[0] # User role # Handle multi-modal input if is_mm: mm_str = num_mm_token * self.mm_token[:-1] + self.mm_token[-1] if self.mm_style == MultiModalConvStyle.MM_ALONE: self._append_message(role, mm_str) # MM token as a standalone message elif self.mm_style == MultiModalConvStyle.MM_INTERLEAF: if self.mm_token not in query: query = f'{mm_str}{query}' # Interleave MM token with query self._append_message(role, query) def assistant_response(self, response, pre_answer_prompt=None, post_answer_prompt=None): """Append the assistant's response to the conversation.""" # Add post-query and pre-query prompts if they exist if pre_answer_prompt is not None: response = f"{pre_answer_prompt}{response}" if post_answer_prompt is not None: response = f"{response}{post_answer_prompt}" role = self.roles[1] # Assistant role self._append_message(role, response) def _append_message(self, role, message): """Helper function to append a message to the conversation.""" message = '' if message is None else message # Handle None messages self.messages.append([role, message]) def copy(self): """Return a deep copy of the conversation.""" return copy.deepcopy(self) llava_next_video_template = Conversation( system="", roles=("USER: ","ASSISTANT:"), messages=[], sep=[" ",""], mm_token='