long / prompts /pf_parse_chat.py
deeme's picture
Upload 111 files
217acfe verified
import functools
import json
import re
import sys
import time
from typing import List, Mapping
from jinja2 import Template
def validate_role(role: str, valid_roles: List[str] = None):
    """Ensure *role* is one of the accepted chat roles.

    Args:
        role: Role name to check; compared by exact list membership.
        valid_roles: Accepted role names. When ``None`` or empty, the
            defaults assistant, function, user and system are used.

    Raises:
        ValueError: If ``role`` is not in the accepted list. The message
            enumerates every accepted role in its quoted prompt form.
    """
    accepted = valid_roles if valid_roles else ["assistant", "function", "user", "system"]
    if role in accepted:
        return

    # Render each accepted role the way it would be written in a prompt header.
    rendered = []
    for name in accepted:
        rendered.append(f'\'{name}:\\n\'')
    valid_roles_str = ','.join(rendered)
    raise ValueError(f"Invalid role: {role}. Valid roles are: {valid_roles_str}")
def try_parse_name_and_content(role_prompt):
    """Extract a ``(name, content)`` pair from a role's prompt text.

    The text must contain a ``name:`` label followed by a newline and a
    whitespace-free name, then a ``content:`` label followed by the content
    (which may span multiple lines).

    Args:
        role_prompt: Text that follows a role header.

    Returns:
        A ``(name, content)`` tuple when the layout is found, else ``None``.
    """
    # Each label may carry up to two leading '#' so authors can use markdown
    # headings; the bare labels remain accepted for backward compatibility.
    name_content_layout = re.compile(
        r"\n*#{0,2}\s*name:\n+\s*(\S+)\s*\n*#{0,2}\s*content:\n?(.*)",
        re.DOTALL,
    )
    found = name_content_layout.search(role_prompt)
    return found.groups() if found else None
def parse_chat(chat_str, images: List = None, valid_roles: List[str] = None):
    """Split a role-annotated prompt into a list of chat message dicts.

    The prompt is cut at every header line of the form ``role:`` (optionally
    prefixed with a single ``#``, matched case-insensitively). The text that
    follows a header becomes that message's content; when it uses the
    ``name:`` / ``content:`` layout, the name is stored on the message too.

    Args:
        chat_str: Prompt text containing role header lines.
        images: Optional image objects. A content line equal to ``str(image)``
            is handed to ``to_content_str_or_list`` as an image reference.
        valid_roles: Role names recognised as headers. Defaults to system,
            user, assistant and function when ``None`` or empty. Names are
            treated as literal text and compared case-insensitively.

    Returns:
        A list of message dicts carrying "role" and "content", plus "name"
        when one was parsed.

    Raises:
        ValueError: If non-blank text appears where a role is expected (for
            example before the first header), or if a function message does
            not follow the name/content layout.
    """
    if not valid_roles:
        valid_roles = ["system", "user", "assistant", "function"]
    # Headers are matched case-insensitively and the captured role is
    # lowercased before validation, so validate against lowercased names;
    # otherwise a custom role such as "User" would match but never validate.
    valid_roles = [role.lower() for role in valid_roles]

    # openai chat api only supports below roles.
    # customer can add single # in front of role name for markdown highlight.
    # and we still support role name without # prefix for backward compatibility.
    # Role names are escaped so they are matched literally and cannot add
    # capture groups, which would break the text/role interleaving of re.split.
    role_alternatives = "|".join(re.escape(role) for role in valid_roles)
    separator = r"(?i)^\s*#?\s*(" + role_alternatives + r")\s*:\s*\n"

    images = images or []
    hash2images = {str(x): x for x in images}

    chunks = re.split(separator, chat_str, flags=re.MULTILINE)
    chat_list = []

    for chunk in chunks:
        last_message = chat_list[-1] if chat_list else None
        if last_message and "role" in last_message and "content" not in last_message:
            # The previous chunk was a role header; this chunk is its body.
            parsed_result = try_parse_name_and_content(chunk)
            if parsed_result is None:
                # "name" is required if the role is "function"
                if last_message["role"] == "function":
                    raise ValueError("Function role must have content.")
                # "name" is optional for other role types.
                last_message["content"] = to_content_str_or_list(chunk, hash2images)
            else:
                last_message["name"] = parsed_result[0]
                last_message["content"] = to_content_str_or_list(parsed_result[1], hash2images)
        else:
            if chunk.strip() == "":
                continue
            # Check if prompt follows chat api message format and has valid role.
            # References: https://platform.openai.com/docs/api-reference/chat/create.
            role = chunk.strip().lower()
            validate_role(role, valid_roles=valid_roles)
            chat_list.append({"role": role})
    return chat_list
def to_content_str_or_list(chat_str: str, hash2images: Mapping):
    """Convert message text into plain text or a list of content parts.

    The stripped text is examined line by line. A line whose stripped form
    is a key of ``hash2images`` becomes an ``image_url`` part; blank lines
    are dropped; every other line becomes a ``text`` part.

    Args:
        chat_str: Raw message content.
        hash2images: Maps a reference string to an image object. The object's
            ``source_url`` is used when present and truthy; otherwise its
            ``to_base64()`` and ``_mime_type`` are combined into a data URL.

    Returns:
        The list of parts when at least one image reference was found,
        otherwise the stripped text unchanged.
    """
    chat_str = chat_str.strip()
    parts = []
    found_image = False

    for line in chat_str.split("\n"):
        key = line.strip()
        if key in hash2images:
            image = hash2images[key]
            image_url = getattr(image, "source_url", None)
            if not image_url:
                # No usable URL on the image: embed its bytes as a data URL.
                image_bs64 = image.to_base64()
                image_mine_type = image._mime_type
                image_url = {"url": f"data:{image_mine_type};base64,{image_bs64}"}
            parts.append({"type": "image_url", "image_url": image_url})
            found_image = True
        elif key != "":
            # Text lines keep their original (unstripped) form.
            parts.append({"type": "text", "text": line})

    if found_image:
        return parts
    return chat_str