Spaces:
Running
Running
Rename bin_public/utils/tools.py to bin_public/utils/utils.py
Browse files
bin_public/utils/{tools.py → utils.py}
RENAMED
|
@@ -5,31 +5,38 @@ import logging
|
|
| 5 |
import json
|
| 6 |
import os
|
| 7 |
import datetime
|
| 8 |
-
import hashlib
|
| 9 |
import csv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
import gradio as gr
|
| 12 |
from pypinyin import lazy_pinyin
|
| 13 |
import tiktoken
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
from bin_public.config.presets import *
|
|
|
|
| 16 |
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
if TYPE_CHECKING:
|
| 20 |
from typing import TypedDict
|
| 21 |
|
|
|
|
| 22 |
class DataframeData(TypedDict):
|
| 23 |
headers: List[str]
|
| 24 |
data: List[List[str | int | bool]]
|
| 25 |
|
| 26 |
|
| 27 |
-
initial_prompt = "You are a helpful assistant."
|
| 28 |
-
API_URL = "https://api.openai.com/v1/chat/completions"
|
| 29 |
-
HISTORY_DIR = "history"
|
| 30 |
-
TEMPLATES_DIR = "templates"
|
| 31 |
-
|
| 32 |
-
|
| 33 |
def count_token(message):
|
| 34 |
encoding = tiktoken.get_encoding("cl100k_base")
|
| 35 |
input_str = f"role: {message['role']}, content: {message['content']}"
|
|
@@ -37,36 +44,99 @@ def count_token(message):
|
|
| 37 |
return length
|
| 38 |
|
| 39 |
|
| 40 |
-
def
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
for i, line in enumerate(lines):
|
| 45 |
-
if "
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
else:
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
|
| 72 |
def construct_text(role, text):
|
|
@@ -89,6 +159,17 @@ def construct_token_message(token, stream=False):
|
|
| 89 |
return f"Token 计数: {token}"
|
| 90 |
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
def delete_last_conversation(chatbot, history, previous_token_count):
|
| 93 |
if len(chatbot) > 0 and standard_error_msg in chatbot[-1][1]:
|
| 94 |
logging.info("由于包含报错信息,只删除chatbot记录")
|
|
@@ -210,7 +291,7 @@ def load_template(filename, mode=0):
|
|
| 210 |
lines = [[i["act"], i["prompt"]] for i in lines]
|
| 211 |
else:
|
| 212 |
with open(
|
| 213 |
-
|
| 214 |
) as csvfile:
|
| 215 |
reader = csv.reader(csvfile)
|
| 216 |
lines = list(reader)
|
|
@@ -245,20 +326,19 @@ def reset_state():
|
|
| 245 |
|
| 246 |
|
| 247 |
def reset_textbox():
|
|
|
|
| 248 |
return gr.update(value="")
|
| 249 |
|
| 250 |
|
| 251 |
def reset_default():
|
| 252 |
-
|
| 253 |
-
API_URL = "https://api.openai.com/v1/chat/completions"
|
| 254 |
os.environ.pop("HTTPS_PROXY", None)
|
| 255 |
os.environ.pop("https_proxy", None)
|
| 256 |
-
return gr.update(value=
|
| 257 |
|
| 258 |
|
| 259 |
def change_api_url(url):
|
| 260 |
-
|
| 261 |
-
API_URL = url
|
| 262 |
msg = f"API地址更改为了{url}"
|
| 263 |
logging.info(msg)
|
| 264 |
return msg
|
|
@@ -288,12 +368,138 @@ def submit_key(key):
|
|
| 288 |
return key, msg
|
| 289 |
|
| 290 |
|
| 291 |
-
def sha1sum(filename):
|
| 292 |
-
sha1 = hashlib.sha1()
|
| 293 |
-
sha1.update(filename.encode("utf-8"))
|
| 294 |
-
return sha1.hexdigest()
|
| 295 |
-
|
| 296 |
-
|
| 297 |
def replace_today(prompt):
|
| 298 |
today = datetime.datetime.today().strftime("%Y-%m-%d")
|
| 299 |
-
return prompt.replace("{current_date}", today)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
import json
|
| 6 |
import os
|
| 7 |
import datetime
|
|
|
|
| 8 |
import csv
|
| 9 |
+
import requests
|
| 10 |
+
import re
|
| 11 |
+
import html
|
| 12 |
+
import sys
|
| 13 |
+
import subprocess
|
| 14 |
|
|
|
|
| 15 |
from pypinyin import lazy_pinyin
|
| 16 |
import tiktoken
|
| 17 |
+
import mdtex2html
|
| 18 |
+
from markdown import markdown
|
| 19 |
+
from pygments import highlight
|
| 20 |
+
from pygments.lexers import get_lexer_by_name
|
| 21 |
+
from pygments.formatters import HtmlFormatter
|
| 22 |
|
| 23 |
from bin_public.config.presets import *
|
| 24 |
+
import bin_public.utils.shared as shared
|
| 25 |
|
| 26 |
+
logging.basicConfig(
|
| 27 |
+
level=logging.INFO,
|
| 28 |
+
format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
|
| 29 |
+
)
|
| 30 |
|
| 31 |
if TYPE_CHECKING:
|
| 32 |
from typing import TypedDict
|
| 33 |
|
| 34 |
+
|
| 35 |
class DataframeData(TypedDict):
|
| 36 |
headers: List[str]
|
| 37 |
data: List[List[str | int | bool]]
|
| 38 |
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
def count_token(message):
|
| 41 |
encoding = tiktoken.get_encoding("cl100k_base")
|
| 42 |
input_str = f"role: {message['role']}, content: {message['content']}"
|
|
|
|
| 44 |
return length
|
| 45 |
|
| 46 |
|
| 47 |
+
def markdown_to_html_with_syntax_highlight(md_str):
|
| 48 |
+
def replacer(match):
|
| 49 |
+
lang = match.group(1) or "text"
|
| 50 |
+
code = match.group(2)
|
| 51 |
+
|
| 52 |
+
try:
|
| 53 |
+
lexer = get_lexer_by_name(lang, stripall=True)
|
| 54 |
+
except ValueError:
|
| 55 |
+
lexer = get_lexer_by_name("text", stripall=True)
|
| 56 |
+
|
| 57 |
+
formatter = HtmlFormatter()
|
| 58 |
+
highlighted_code = highlight(code, lexer, formatter)
|
| 59 |
+
|
| 60 |
+
return f'<pre><code class="{lang}">{highlighted_code}</code></pre>'
|
| 61 |
+
|
| 62 |
+
code_block_pattern = r"```(\w+)?\n([\s\S]+?)\n```"
|
| 63 |
+
md_str = re.sub(code_block_pattern, replacer, md_str, flags=re.MULTILINE)
|
| 64 |
+
|
| 65 |
+
html_str = markdown(md_str)
|
| 66 |
+
return html_str
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def normalize_markdown(md_text: str) -> str:
|
| 70 |
+
lines = md_text.split("\n")
|
| 71 |
+
normalized_lines = []
|
| 72 |
+
inside_list = False
|
| 73 |
+
|
| 74 |
for i, line in enumerate(lines):
|
| 75 |
+
if re.match(r"^(\d+\.|-|\*|\+)\s", line.strip()):
|
| 76 |
+
if not inside_list and i > 0 and lines[i - 1].strip() != "":
|
| 77 |
+
normalized_lines.append("")
|
| 78 |
+
inside_list = True
|
| 79 |
+
normalized_lines.append(line)
|
| 80 |
+
elif inside_list and line.strip() == "":
|
| 81 |
+
if i < len(lines) - 1 and not re.match(
|
| 82 |
+
r"^(\d+\.|-|\*|\+)\s", lines[i + 1].strip()
|
| 83 |
+
):
|
| 84 |
+
normalized_lines.append(line)
|
| 85 |
+
continue
|
| 86 |
else:
|
| 87 |
+
inside_list = False
|
| 88 |
+
normalized_lines.append(line)
|
| 89 |
+
|
| 90 |
+
return "\n".join(normalized_lines)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def convert_mdtext(md_text):
|
| 94 |
+
code_block_pattern = re.compile(r"```(.*?)(?:```|$)", re.DOTALL)
|
| 95 |
+
inline_code_pattern = re.compile(r"`(.*?)`", re.DOTALL)
|
| 96 |
+
code_blocks = code_block_pattern.findall(md_text)
|
| 97 |
+
non_code_parts = code_block_pattern.split(md_text)[::2]
|
| 98 |
+
|
| 99 |
+
result = []
|
| 100 |
+
for non_code, code in zip(non_code_parts, code_blocks + [""]):
|
| 101 |
+
if non_code.strip():
|
| 102 |
+
non_code = normalize_markdown(non_code)
|
| 103 |
+
if inline_code_pattern.search(non_code):
|
| 104 |
+
result.append(markdown(non_code, extensions=["tables"]))
|
| 105 |
+
else:
|
| 106 |
+
result.append(mdtex2html.convert(non_code, extensions=["tables"]))
|
| 107 |
+
if code.strip():
|
| 108 |
+
# _, code = detect_language(code) # 暂时去除代码高亮功能,因为在大段代码的情况下会出现问题
|
| 109 |
+
# code = code.replace("\n\n", "\n") # 暂时去除代码中的空行,因为在大段代码的情况下会出现问题
|
| 110 |
+
code = f"\n```{code}\n\n```"
|
| 111 |
+
code = markdown_to_html_with_syntax_highlight(code)
|
| 112 |
+
result.append(code)
|
| 113 |
+
result = "".join(result)
|
| 114 |
+
result += ALREADY_CONVERTED_MARK
|
| 115 |
+
return result
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def convert_asis(userinput):
|
| 119 |
+
return (
|
| 120 |
+
f'<p style="white-space:pre-wrap;">{html.escape(userinput)}</p>'
|
| 121 |
+
+ ALREADY_CONVERTED_MARK
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def detect_converted_mark(userinput):
|
| 126 |
+
if userinput.endswith(ALREADY_CONVERTED_MARK):
|
| 127 |
+
return True
|
| 128 |
+
else:
|
| 129 |
+
return False
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def detect_language(code):
|
| 133 |
+
if code.startswith("\n"):
|
| 134 |
+
first_line = ""
|
| 135 |
+
else:
|
| 136 |
+
first_line = code.strip().split("\n", 1)[0]
|
| 137 |
+
language = first_line.lower() if first_line else ""
|
| 138 |
+
code_without_language = code[len(first_line):].lstrip() if first_line else code
|
| 139 |
+
return language, code_without_language
|
| 140 |
|
| 141 |
|
| 142 |
def construct_text(role, text):
|
|
|
|
| 159 |
return f"Token 计数: {token}"
|
| 160 |
|
| 161 |
|
| 162 |
+
def delete_first_conversation(history, previous_token_count):
|
| 163 |
+
if history:
|
| 164 |
+
del history[:2]
|
| 165 |
+
del previous_token_count[0]
|
| 166 |
+
return (
|
| 167 |
+
history,
|
| 168 |
+
previous_token_count,
|
| 169 |
+
construct_token_message(sum(previous_token_count)),
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
def delete_last_conversation(chatbot, history, previous_token_count):
|
| 174 |
if len(chatbot) > 0 and standard_error_msg in chatbot[-1][1]:
|
| 175 |
logging.info("由于包含报错信息,只删除chatbot记录")
|
|
|
|
| 291 |
lines = [[i["act"], i["prompt"]] for i in lines]
|
| 292 |
else:
|
| 293 |
with open(
|
| 294 |
+
os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8"
|
| 295 |
) as csvfile:
|
| 296 |
reader = csv.reader(csvfile)
|
| 297 |
lines = list(reader)
|
|
|
|
| 326 |
|
| 327 |
|
| 328 |
def reset_textbox():
|
| 329 |
+
logging.debug("重置文本框")
|
| 330 |
return gr.update(value="")
|
| 331 |
|
| 332 |
|
| 333 |
def reset_default():
|
| 334 |
+
newurl = shared.state.reset_api_url()
|
|
|
|
| 335 |
os.environ.pop("HTTPS_PROXY", None)
|
| 336 |
os.environ.pop("https_proxy", None)
|
| 337 |
+
return gr.update(value=newurl), gr.update(value=""), "API URL 和代理已重置"
|
| 338 |
|
| 339 |
|
| 340 |
def change_api_url(url):
|
| 341 |
+
shared.state.set_api_url(url)
|
|
|
|
| 342 |
msg = f"API地址更改为了{url}"
|
| 343 |
logging.info(msg)
|
| 344 |
return msg
|
|
|
|
| 368 |
return key, msg
|
| 369 |
|
| 370 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
def replace_today(prompt):
|
| 372 |
today = datetime.datetime.today().strftime("%Y-%m-%d")
|
| 373 |
+
return prompt.replace("{current_date}", today)
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def get_geoip():
|
| 377 |
+
response = requests.get("https://ipapi.co/json/", timeout=5)
|
| 378 |
+
try:
|
| 379 |
+
data = response.json()
|
| 380 |
+
except:
|
| 381 |
+
data = {"error": True, "reason": "连接ipapi失败"}
|
| 382 |
+
if "error" in data.keys():
|
| 383 |
+
logging.warning(f"无法获取IP地址信息。\n{data}")
|
| 384 |
+
if data["reason"] == "RateLimited":
|
| 385 |
+
return (
|
| 386 |
+
f"获取IP地理位置失败,因为达到了检测IP的速率限制。聊天功能可能仍然可用,但请注意,如果您的IP地址在不受支持的地区,您可能会遇到问题。"
|
| 387 |
+
)
|
| 388 |
+
else:
|
| 389 |
+
return f"获取IP地理位置失败。原因:{data['reason']}。你仍然可以使用聊天功能。"
|
| 390 |
+
else:
|
| 391 |
+
country = data["country_name"]
|
| 392 |
+
if country == "China":
|
| 393 |
+
text = "**您的IP区域:中国。请立即检查代理设置,在不受支持的地区使用API可能导致账号被封禁。**"
|
| 394 |
+
else:
|
| 395 |
+
text = f"您的IP区域:{country}。"
|
| 396 |
+
logging.info(text)
|
| 397 |
+
return text
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
def find_n(lst, max_num):
|
| 401 |
+
n = len(lst)
|
| 402 |
+
total = sum(lst)
|
| 403 |
+
|
| 404 |
+
if total < max_num:
|
| 405 |
+
return n
|
| 406 |
+
|
| 407 |
+
for i in range(len(lst)):
|
| 408 |
+
if total - lst[i] < max_num:
|
| 409 |
+
return n - i - 1
|
| 410 |
+
total = total - lst[i]
|
| 411 |
+
return 1
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def start_outputing():
|
| 415 |
+
logging.debug("显示取消按钮,隐藏发送按钮")
|
| 416 |
+
return gr.Button.update(visible=False), gr.Button.update(visible=True)
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def end_outputing():
|
| 420 |
+
return (
|
| 421 |
+
gr.Button.update(visible=True),
|
| 422 |
+
gr.Button.update(visible=False),
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
def cancel_outputing():
|
| 427 |
+
logging.info("中止输出……")
|
| 428 |
+
shared.state.interrupt()
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def transfer_input(inputs):
|
| 432 |
+
# 一次性返回,降低延迟
|
| 433 |
+
textbox = reset_textbox()
|
| 434 |
+
outputing = start_outputing()
|
| 435 |
+
return (
|
| 436 |
+
inputs,
|
| 437 |
+
gr.update(value=""),
|
| 438 |
+
gr.Button.update(visible=False),
|
| 439 |
+
gr.Button.update(visible=True),
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def get_proxies():
|
| 444 |
+
# 获取环境变量中的代理设置
|
| 445 |
+
http_proxy = os.environ.get("HTTP_PROXY") or os.environ.get("http_proxy")
|
| 446 |
+
https_proxy = os.environ.get("HTTPS_PROXY") or os.environ.get("https_proxy")
|
| 447 |
+
|
| 448 |
+
# 如果存在代理设置,使用它们
|
| 449 |
+
proxies = {}
|
| 450 |
+
if http_proxy:
|
| 451 |
+
logging.info(f"使用 HTTP 代理: {http_proxy}")
|
| 452 |
+
proxies["http"] = http_proxy
|
| 453 |
+
if https_proxy:
|
| 454 |
+
logging.info(f"使用 HTTPS 代理: {https_proxy}")
|
| 455 |
+
proxies["https"] = https_proxy
|
| 456 |
+
|
| 457 |
+
if proxies == {}:
|
| 458 |
+
proxies = None
|
| 459 |
+
|
| 460 |
+
return proxies
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
def run(command, desc=None, errdesc=None, custom_env=None, live=False):
|
| 464 |
+
if desc is not None:
|
| 465 |
+
print(desc)
|
| 466 |
+
if live:
|
| 467 |
+
result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
|
| 468 |
+
if result.returncode != 0:
|
| 469 |
+
raise RuntimeError(f"""{errdesc or 'Error running command'}.
|
| 470 |
+
Command: {command}
|
| 471 |
+
Error code: {result.returncode}""")
|
| 472 |
+
|
| 473 |
+
return ""
|
| 474 |
+
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True,
|
| 475 |
+
env=os.environ if custom_env is None else custom_env)
|
| 476 |
+
if result.returncode != 0:
|
| 477 |
+
message = f"""{errdesc or 'Error running command'}.
|
| 478 |
+
Command: {command}
|
| 479 |
+
Error code: {result.returncode}
|
| 480 |
+
stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout) > 0 else '<empty>'}
|
| 481 |
+
stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr) > 0 else '<empty>'}
|
| 482 |
+
"""
|
| 483 |
+
raise RuntimeError(message)
|
| 484 |
+
return result.stdout.decode(encoding="utf8", errors="ignore")
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
def versions_html():
|
| 488 |
+
git = os.environ.get('GIT', "git")
|
| 489 |
+
python_version = ".".join([str(x) for x in sys.version_info[0:3]])
|
| 490 |
+
try:
|
| 491 |
+
commit_hash = run(f"{git} rev-parse HEAD").strip()
|
| 492 |
+
except Exception:
|
| 493 |
+
commit_hash = "<none>"
|
| 494 |
+
if commit_hash != "<none>":
|
| 495 |
+
short_commit = commit_hash[0:7]
|
| 496 |
+
commit_info = f"<a style=\"text-decoration:none\" href=\"https://github.com/GaiZhenbiao/ChuanhuChatGPT/commit/{short_commit}\">{short_commit}</a>"
|
| 497 |
+
else:
|
| 498 |
+
commit_info = "unknown \U0001F615"
|
| 499 |
+
return f"""
|
| 500 |
+
Python: <span title="{sys.version}">{python_version}</span>
|
| 501 |
+
•
|
| 502 |
+
Gradio: {gr.__version__}
|
| 503 |
+
•
|
| 504 |
+
Commit: {commit_info}
|
| 505 |
+
"""
|