File size: 5,520 Bytes
a4b70d9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
from __future__ import annotations
import random
import string
from pathlib import Path
from ..typing import Messages, Cookies, AsyncIterator, Iterator
from ..tools.files import get_bucket_dir, read_bucket
from .. import debug
def to_string(value) -> str:
if isinstance(value, str):
return value
elif isinstance(value, dict):
if "text" in value:
return value["text"]
elif "name" in value:
return ""
elif "bucket_id" in value:
bucket_dir = Path(get_bucket_dir(value.get("bucket_id")))
return "".join(read_bucket(bucket_dir))
return ""
elif isinstance(value, list):
return "".join([to_string(v) for v in value if v.get("type", "text") == "text"])
elif value is None:
return ""
return str(value)
def render_messages(messages: Messages) -> Iterator:
for idx, message in enumerate(messages):
if isinstance(message, dict) and isinstance(message.get("content"), list):
yield {
**message,
"content": to_string(message["content"]),
}
else:
yield message
def format_prompt(messages: Messages, add_special_tokens: bool = False, do_continue: bool = False, include_system: bool = True) -> str:
"""
Format a series of messages into a single string, optionally adding special tokens.
Args:
messages (Messages): A list of message dictionaries, each containing 'role' and 'content'.
add_special_tokens (bool): Whether to add special formatting tokens.
Returns:
str: A formatted string containing all messages.
"""
if not add_special_tokens and len(messages) <= 1:
return to_string(messages[0]["content"])
messages = [
(message["role"], to_string(message["content"]))
for message in messages
if include_system or message.get("role") not in ("developer", "system")
]
formatted = "\n".join([
f'{role.capitalize()}: {content}'
for role, content in messages
if content.strip()
])
if do_continue:
return formatted
return f"{formatted}\nAssistant:"
def get_system_prompt(messages: Messages) -> str:
return "\n".join([m["content"] for m in messages if m["role"] in ("developer", "system")])
def get_last_user_message(messages: Messages, include_buckets: bool = True) -> str:
user_messages = []
for message in messages[::-1]:
if message.get("role") == "user" or not user_messages:
if message.get("role") != "user":
continue
content = message.get("content")
if include_buckets:
content = to_string(content).strip()
if isinstance(content, str):
user_messages.append(content)
else:
for content_item in content:
if content_item.get("type") == "text":
content = content_item.get("text").strip()
if content:
user_messages.append(content)
else:
return "\n".join(user_messages[::-1])
return "\n".join(user_messages[::-1])
def get_last_message(messages: Messages, prompt: str = None) -> str:
if prompt is None:
for message in messages[::-1]:
content = to_string(message.get("content")).strip()
if content:
prompt = content
return prompt
def format_media_prompt(messages, prompt: str = None) -> str:
if prompt is None:
return get_last_user_message(messages)
return prompt
def format_prompt_max_length(messages: Messages, max_lenght: int) -> str:
prompt = format_prompt(messages)
start = len(prompt)
if start > max_lenght:
if len(messages) > 6:
prompt = format_prompt(messages[:3] + messages[-3:])
if len(prompt) > max_lenght:
if len(messages) > 2:
prompt = format_prompt([m for m in messages if m["role"] == "system"] + messages[-1:])
if len(prompt) > max_lenght:
prompt = messages[-1]["content"]
debug.log(f"Messages trimmed from: {start} to: {len(prompt)}")
return prompt
def get_random_string(length: int = 10) -> str:
"""
Generate a random string of specified length, containing lowercase letters and digits.
Args:
length (int, optional): Length of the random string to generate. Defaults to 10.
Returns:
str: A random string of the specified length.
"""
return ''.join(
random.choice(string.ascii_lowercase + string.digits)
for _ in range(length)
)
def get_random_hex(length: int = 32) -> str:
"""
Generate a random hexadecimal string with n length.
Returns:
str: A random hexadecimal string of n characters.
"""
return ''.join(
random.choice("abcdef" + string.digits)
for _ in range(length)
)
def filter_none(**kwargs) -> dict:
return {
key: value
for key, value in kwargs.items()
if value is not None
}
async def async_concat_chunks(chunks: AsyncIterator) -> str:
return concat_chunks([chunk async for chunk in chunks])
def concat_chunks(chunks: Iterator) -> str:
return "".join([
str(chunk) for chunk in chunks
if chunk and not isinstance(chunk, Exception)
])
def format_cookies(cookies: Cookies) -> str:
return "; ".join([f"{k}={v}" for k, v in cookies.items()]) |