|
|
import functools
|
|
|
import json
|
|
|
import re
|
|
|
import sys
|
|
|
import time
|
|
|
from typing import List, Mapping
|
|
|
|
|
|
from jinja2 import Template
|
|
|
|
|
|
|
|
|
def validate_role(role: str, valid_roles: List[str] = None):
    """Raise ``ValueError`` when *role* is not one of the accepted chat roles.

    Args:
        role: The role name to check.
        valid_roles: Accepted role names. When omitted or empty, the default
            set ``assistant``, ``function``, ``user``, ``system`` is used.

    Raises:
        ValueError: If *role* is not in the accepted set. The message lists
            every accepted role in the ``'<role>:\\n'`` form.
    """
    accepted = valid_roles or ["assistant", "function", "user", "system"]
    if role in accepted:
        return

    # Show each accepted role the way it is written in a prompt: quoted,
    # followed by a colon and a literal backslash-n.
    valid_roles_str = ','.join(f"'{name}:\\n'" for name in accepted)
    raise ValueError(f"Invalid role: {role}. Valid roles are: {valid_roles_str}")
|
|
|
|
|
|
|
|
|
def try_parse_name_and_content(role_prompt):
    """Extract a ``name:`` / ``content:`` pair from the text of one message.

    The text is searched for a ``name:`` label (optionally prefixed by up to
    two ``#``), followed on a later line by a whitespace-free name, then a
    ``content:`` label (same optional prefix) and the remaining text.

    Args:
        role_prompt: The text to search.

    Returns:
        A ``(name, content)`` tuple when the pattern is found, else ``None``.
        ``content`` is everything after the ``content:`` label (minus one
        optional newline) and may span multiple lines.
    """
    # DOTALL lets the trailing ``(.*)`` capture multi-line content.
    name_content_re = re.compile(
        r"\n*#{0,2}\s*name:\n+\s*(\S+)\s*\n*#{0,2}\s*content:\n?(.*)",
        re.DOTALL,
    )
    found = name_content_re.search(role_prompt)
    return (found.group(1), found.group(2)) if found else None
|
|
|
|
|
|
|
|
|
def parse_chat(chat_str, images: List = None, valid_roles: List[str] = None):
    """Split a chat prompt into a list of message dictionaries.

    The prompt is cut at every role header line (for example ``user:`` or
    ``# System:``; matching is case-insensitive). Each header starts a new
    message and the text that follows it becomes that message's content.

    Args:
        chat_str: The full chat prompt text.
        images: Optional image objects. Each is indexed by ``str(image)`` so
            that a content line equal to that string can be resolved by
            ``to_content_str_or_list``.
        valid_roles: Role names recognised as headers. When omitted or empty,
            ``system``, ``user``, ``assistant`` and ``function`` are used.

    Returns:
        A list of dicts, each with a lower-cased ``"role"``, and, once the
        text following the header has been processed, a ``"content"`` and
        optionally a ``"name"``.

    Raises:
        ValueError: If a ``function`` message has no parsable name/content
            pair, or if ``validate_role`` rejects a chunk found where a role
            was expected (such as text before the first header).
    """
    roles = valid_roles if valid_roles else ["system", "user", "assistant", "function"]

    # The capturing group makes re.split return the matched role names
    # interleaved with the text between headers.
    header_pattern = r"(?i)^\s*#?\s*(" + "|".join(roles) + r")\s*:\s*\n"

    image_by_key = {str(image): image for image in (images or [])}

    messages = []
    for piece in re.split(header_pattern, chat_str, flags=re.MULTILINE):
        pending = messages[-1] if messages else None
        expects_body = bool(pending) and "role" in pending and "content" not in pending

        if not expects_body:
            # No message is waiting for content, so this piece must be a role
            # name; blank pieces are ignored.
            label = piece.strip()
            if label == "":
                continue
            role = label.lower()
            validate_role(role, valid_roles=roles)
            messages.append({"role": role})
            continue

        # The previous piece was a role header: this piece is its body.
        named = try_parse_name_and_content(piece)
        if named is not None:
            pending["name"] = named[0]
            pending["content"] = to_content_str_or_list(named[1], image_by_key)
        elif pending["role"] == "function":
            raise ValueError("Function role must have content.")
        else:
            pending["content"] = to_content_str_or_list(piece, image_by_key)

    return messages
|
|
|
|
|
|
|
|
|
def to_content_str_or_list(chat_str: str, hash2images: Mapping):
    """Convert message text into plain text or a list of typed content parts.

    The stripped text is processed line by line. A line whose stripped form
    is a key of *hash2images* becomes an ``image_url`` part; blank lines are
    dropped; every other line becomes a ``text`` part (kept unstripped).

    Args:
        chat_str: The message text.
        hash2images: Mapping from a placeholder string to an image object.
            The object's ``source_url`` attribute is used when present and
            truthy; otherwise a ``data:`` URL is built from its
            ``to_base64()`` result and its ``_mime_type`` attribute.

    Returns:
        The list of content parts if at least one image line was found,
        otherwise the stripped text as a single string.
    """
    chat_str = chat_str.strip()

    parts = []
    found_image = False
    for line in chat_str.split("\n"):
        key = line.strip()
        if key in hash2images:
            image = hash2images[key]
            url = getattr(image, "source_url", None)
            if not url:
                # No usable source URL: embed the image as a base64 data URL.
                encoded = image.to_base64()
                mime_type = image._mime_type
                url = {"url": f"data:{mime_type};base64,{encoded}"}
            parts.append({"type": "image_url", "image_url": url})
            found_image = True
        elif key != "":
            parts.append({"type": "text", "text": line})

    return parts if found_image else chat_str
|
|
|
|
|
|
|