Upload 111 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .env +86 -0
- Dockerfile +8 -0
- README.md +6 -5
- app.py +341 -0
- backend_utils.py +22 -0
- config.py +81 -0
- core/__init__.py +1 -0
- core/backend.py +218 -0
- core/diff_utils.py +173 -0
- core/draft_writer.py +46 -0
- core/frontend.py +435 -0
- core/frontend_copy.py +35 -0
- core/frontend_setting.py +345 -0
- core/frontend_utils.py +333 -0
- core/outline_writer.py +88 -0
- core/parser_utils.py +32 -0
- core/plot_writer.py +49 -0
- core/summary_novel.py +94 -0
- core/writer.py +533 -0
- core/writer_utils.py +216 -0
- custom/根据提纲创作正文/天蚕土豆风格.txt +14 -0
- custom/根据提纲创作正文/对草稿进行润色.txt +7 -0
- healthcheck.py +24 -0
- llm_api/__init__.py +109 -0
- llm_api/baidu_api.py +48 -0
- llm_api/chat_messages.py +116 -0
- llm_api/doubao_api.py +53 -0
- llm_api/model_prices.json +0 -0
- llm_api/mongodb_cache.py +127 -0
- llm_api/mongodb_cost.py +121 -0
- llm_api/mongodb_init.py +7 -0
- llm_api/openai_api.py +67 -0
- llm_api/sparkai_api.py +66 -0
- llm_api/zhipuai_api.py +54 -0
- prompts/baseprompt.py +105 -0
- prompts/chat_utils.py +40 -0
- prompts/common_parser.py +21 -0
- prompts/idea-examples.yaml +9 -0
- prompts/pf_parse_chat.py +94 -0
- prompts/prompt_utils.py +128 -0
- prompts/test_format_plot.yaml +28 -0
- prompts/test_prompt.py +22 -0
- prompts/text-plot-examples.yaml +227 -0
- prompts/tool_parser.py +39 -0
- prompts/tool_polish.py +23 -0
- prompts/创作剧情/context_prompt.txt +35 -0
- prompts/创作剧情/prompt.py +26 -0
- prompts/创作剧情/system_prompt.txt +49 -0
- prompts/创作剧情/扩写剧情.txt +14 -0
- prompts/创作剧情/新建剧情.txt +35 -0
.env
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Thread Configuration - 线程配置
|
| 2 |
+
# 生成时采用的最大线程数,5-10即可。会带来成倍的API调用费用,不要设置过高!
|
| 3 |
+
MAX_THREAD_NUM=1
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
# Server Configuration - Docker服务配置
|
| 7 |
+
# 前端服务端口
|
| 8 |
+
FRONTEND_PORT=80
|
| 9 |
+
# 后端服务端口
|
| 10 |
+
BACKEND_PORT=7869
|
| 11 |
+
# 后端服务监听地址
|
| 12 |
+
BACKEND_HOST=0.0.0.0
|
| 13 |
+
# Gunicorn工作进程数
|
| 14 |
+
WORKERS=4
|
| 15 |
+
# 每个工作进程的线程数
|
| 16 |
+
THREADS=2
|
| 17 |
+
# 请求超时时间(秒)
|
| 18 |
+
TIMEOUT=120
|
| 19 |
+
|
| 20 |
+
# 是否启用在线演示
|
| 21 |
+
# 不用设置,默认不启用
|
| 22 |
+
ENABLE_ONLINE_DEMO=False
|
| 23 |
+
|
| 24 |
+
# Backend Configuration - 后端配置
|
| 25 |
+
# 导入小说时,最大的处理长度,超出该长度的文本不会进行处理,可以考虑增加
|
| 26 |
+
MAX_NOVEL_SUMMARY_LENGTH=20000
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# MongoDB Configuration - MongoDB数据库配置
|
| 30 |
+
# 安装了MongoDB才需要配置,否则不用改动
|
| 31 |
+
# 是否启用MongoDB,启用后下面配置才有效
|
| 32 |
+
ENABLE_MONGODB=false
|
| 33 |
+
# MongoDB连接地址,使用host.docker.internal访问宿主机MongoDB
|
| 34 |
+
MONGODB_URI=mongodb://host.docker.internal:27017/
|
| 35 |
+
# MongoDB数据库名称
|
| 36 |
+
MONGODB_DB_NAME=llm_api
|
| 37 |
+
# 是否启用API缓存
|
| 38 |
+
ENABLE_MONGODB_CACHE=true
|
| 39 |
+
# 缓存命中后重放速度倍率
|
| 40 |
+
CACHE_REPLAY_SPEED=2
|
| 41 |
+
# 缓存命中后最大延迟时间(秒)
|
| 42 |
+
CACHE_REPLAY_MAX_DELAY=5
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# API Cost Limits - API费用限制设置,需要依赖于MongoDB
|
| 46 |
+
# 每小时费用上限(人民币)
|
| 47 |
+
API_HOURLY_LIMIT_RMB=100
|
| 48 |
+
# 每天费用上限(人民币)
|
| 49 |
+
API_DAILY_LIMIT_RMB=500
|
| 50 |
+
# 美元兑人民币汇率
|
| 51 |
+
API_USD_TO_RMB_RATE=7
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# Wenxin API Settings - 文心API配置
|
| 55 |
+
# 文心API的AK,获取地址:https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application
|
| 56 |
+
WENXIN_AK=
|
| 57 |
+
WENXIN_SK=
|
| 58 |
+
WENXIN_AVAILABLE_MODELS=ERNIE-Novel-8K,ERNIE-4.0-8K,ERNIE-3.5-8K
|
| 59 |
+
|
| 60 |
+
# Doubao API Settings - 豆包API配置
|
| 61 |
+
# DOUBAO_ENDPOINT_IDS和DOUBAO_AVAILABLE_MODELS一一对应,有几个模型就对应几个endpoint_id,这是豆包强制要求的
|
| 62 |
+
# 你可以自行设置DOUBAO_AVAILABLE_MODELS,不一定非要采用下面的
|
| 63 |
+
DOUBAO_API_KEY=
|
| 64 |
+
DOUBAO_ENDPOINT_IDS=
|
| 65 |
+
DOUBAO_AVAILABLE_MODELS=doubao-pro-32k,doubao-lite-32k
|
| 66 |
+
|
| 67 |
+
# GPT API Settings - GPT API配置
|
| 68 |
+
|
| 69 |
+
GPT_AVAILABLE_MODELS=deepseek-r1,gemini-2.0-pro-exp-02-05,lmsys/claude-3-5-sonnet-20241022,windsurf/claude-3-5-sonnet,claude-3-5-sonnet-20240620,o3-mini,gpt-4-turbo-2024-04-09,gemini-2.0-flash-thinking-exp
|
| 70 |
+
|
| 71 |
+
# Local Model Settings - 本地模型配置
|
| 72 |
+
# 本地模型配置需要把下面的localhost替换为host.docker.internal,把8000替换为你的本地大模型服务端口
|
| 73 |
+
# 把local-key替换为你的本地大模型服务API_KEY,把local-model-1替换为你的本地大模型服务模型名
|
| 74 |
+
# 并且docker启动方式有变化,详细参考readme
|
| 75 |
+
LOCAL_BASE_URL=http://localhost:8000/v1
|
| 76 |
+
LOCAL_API_KEY=local-key
|
| 77 |
+
LOCAL_AVAILABLE_MODELS=local-model-1
|
| 78 |
+
|
| 79 |
+
# Zhipuai API Settings - 智谱AI配置
|
| 80 |
+
ZHIPUAI_API_KEY=
|
| 81 |
+
ZHIPUAI_AVAILABLE_MODELS=glm-4-air,glm-4-flashx
|
| 82 |
+
|
| 83 |
+
# Default Model Settings - 默认模型设置
|
| 84 |
+
# 例如:wenxin/ERNIE-Novel-8K, doubao/doubao-pro-32k, gpt/gpt-4o-mini, local/local-model-1
|
| 85 |
+
DEFAULT_MAIN_MODEL=gpt/claude-3-5-sonnet-20240620
|
| 86 |
+
DEFAULT_SUB_MODEL=gpt/gemini-2.0-pro-exp-02-05
|
Dockerfile
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.12-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
COPY . .
|
| 6 |
+
RUN pip install -r requirements.txt
|
| 7 |
+
|
| 8 |
+
CMD ["python app.py"]
|
README.md
CHANGED
|
@@ -1,10 +1,11 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
+
title: long
|
| 3 |
+
emoji: 👀
|
| 4 |
+
colorFrom: purple
|
| 5 |
+
colorTo: yellow
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
license: gpl-3.0
|
| 9 |
+
app_port: 7869
|
| 10 |
---
|
| 11 |
|
|
|
app.py
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
from flask import Flask, request, Response, jsonify
|
| 5 |
+
from flask_cors import CORS
|
| 6 |
+
app = Flask(__name__)
|
| 7 |
+
CORS(app)
|
| 8 |
+
|
| 9 |
+
import sys
|
| 10 |
+
import os
|
| 11 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 12 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 13 |
+
|
| 14 |
+
from prompts.baseprompt import clean_txt_content, load_prompt
|
| 15 |
+
|
| 16 |
+
from core.writer_utils import KeyPointMsg
|
| 17 |
+
from core.draft_writer import DraftWriter
|
| 18 |
+
from core.plot_writer import PlotWriter
|
| 19 |
+
from core.outline_writer import OutlineWriter
|
| 20 |
+
|
| 21 |
+
from setting import setting_bp
|
| 22 |
+
from summary import process_novel
|
| 23 |
+
from backend_utils import get_model_config_from_provider_model
|
| 24 |
+
from config import MAX_NOVEL_SUMMARY_LENGTH, MAX_THREAD_NUM, ENABLE_ONLINE_DEMO
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
app.register_blueprint(setting_bp)
|
| 28 |
+
|
| 29 |
+
# 添加配置
|
| 30 |
+
BACKEND_HOST = os.environ.get('BACKEND_HOST', '0.0.0.0')
|
| 31 |
+
BACKEND_PORT = int(os.environ.get('BACKEND_PORT', 7869))
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: report service health and the current Unix timestamp."""
    payload = {
        'status': 'healthy',
        'timestamp': int(time.time()),
    }
    return jsonify(payload), 200
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def load_novel_writer(writer_mode, chunk_list, global_context, x_chunk_length, y_chunk_length, main_model, sub_model, max_thread_num) -> DraftWriter:
    """Instantiate the writer matching writer_mode ('draft' / 'outline' / 'plot').

    The global_context payload is wrapped differently per mode: drafts get an
    empty context, outlines wrap it under 'summary', plots under 'chapter'.

    Raises:
        ValueError: when writer_mode is not one of the three known modes.
    """
    shared_kwargs = {
        'xy_pairs': chunk_list,
        'model': get_model_config_from_provider_model(main_model),
        'sub_model': get_model_config_from_provider_model(sub_model),
        'x_chunk_length': x_chunk_length,
        'y_chunk_length': y_chunk_length,
        'max_thread_num': max_thread_num,
    }

    if writer_mode == 'draft':
        return DraftWriter(global_context={}, **shared_kwargs)
    if writer_mode == 'outline':
        return OutlineWriter(global_context={'summary': global_context}, **shared_kwargs)
    if writer_mode == 'plot':
        return PlotWriter(global_context={'chapter': global_context}, **shared_kwargs)

    raise ValueError(f"unknown writer: {writer_mode}")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
prompt_names = dict(
    outline=['新建章节', '扩写章节', '润色章节'],
    plot=['新建剧情', '扩写剧情', '润色剧情'],
    draft=['新建正文', '扩写正文', '润色正文'],
)

prompt_dirname = dict(
    outline='prompts/创作章节',
    plot='prompts/创作剧情',
    draft='prompts/创作正文',
)


def _load_prompt_entry(dirname, name):
    """Load one prompt file and strip the leading "user:\n" role marker."""
    text = clean_txt_content(load_prompt(dirname, name))
    return {'content': text.removeprefix("user:\n")}


# Preload every built-in prompt at import time, grouped by writer type.
PROMPTS = {
    type_name: {
        'prompt_names': prompt_names[type_name],
        **{name: _load_prompt_entry(dirname, name) for name in prompt_names[type_name]},
    }
    for type_name, dirname in prompt_dirname.items()
}


@app.route('/prompts', methods=['GET'])
def get_prompts():
    """Return all built-in prompts grouped by writer type."""
    return jsonify(PROMPTS)
|
| 97 |
+
|
| 98 |
+
def get_delta_chunks(prev_chunks, curr_chunks):
    """Compute a streaming delta between two snapshots of chunk lists.

    Returns ("delta", suffixes) when every string in curr_chunks is an
    extension (prefix-preserving growth) of the corresponding string in
    prev_chunks; otherwise returns ("init", curr_chunks) so the client
    replaces its state wholesale.
    """
    if not prev_chunks or len(prev_chunks) != len(curr_chunks):
        return "init", curr_chunks

    paired = list(zip(prev_chunks, curr_chunks))

    # Every chunk must keep its length and every string must grow in place.
    extends = all(
        len(old_chunk) == len(new_chunk)
        and all(new_s.startswith(old_s) for old_s, new_s in zip(old_chunk, new_chunk))
        for old_chunk, new_chunk in paired
    )
    if not extends:
        return "init", curr_chunks

    suffixes = [
        [new_s[len(old_s):] for old_s, new_s in zip(old_chunk, new_chunk)]
        for old_chunk, new_chunk in paired
    ]
    return "delta", suffixes
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def call_write(writer_mode, chunk_list, global_context, chunk_span, prompt_content, x_chunk_length, y_chunk_length, main_model, sub_model, max_thread_num, only_prompt):
    """Drive one writing pass of the selected writer, yielding SSE-ready dicts.

    Yields progress dicts shaped by delta_wrapper ({'done', 'chunk_type',
    'chunk_list', 'msg'}). When only_prompt is truthy, yields a single
    {'prompts': [...]} payload and returns early.

    Raises:
        Exception: in online-demo mode when max_thread_num exceeds the cap.
    """
    if ENABLE_ONLINE_DEMO:
        if max_thread_num > MAX_THREAD_NUM:
            raise Exception("在线Demo模型下,最大线程数不能超过" + str(MAX_THREAD_NUM) + "!")

    # Each incoming chunk needs its trailing newline restored, except for the
    # last row (chunks arrive separately from individual page widgets).
    chunk_list = [[e.strip() + ('\n' if e.strip() and rowi != len(chunk_list)-1 else '') for e in row] for rowi, row in enumerate(chunk_list)]

    prev_chunks = None
    def delta_wrapper(chunk_list, done=False, msg=None):
        # Chunks returned to the frontend are stripped of the newline again.
        chunk_list = [[e.strip() for e in row] for row in chunk_list]

        nonlocal prev_chunks
        if prev_chunks is None:
            prev_chunks = chunk_list
            return {
                "done": done,
                "chunk_type": "init",
                "chunk_list": chunk_list,
                "msg": msg
            }
        else:
            chunk_type, new_chunks = get_delta_chunks(prev_chunks, chunk_list)
            prev_chunks = chunk_list
            return {
                "done": done,
                "chunk_type": chunk_type,
                "chunk_list": new_chunks,
                "msg": msg
            }

    novel_writer = load_novel_writer(writer_mode, chunk_list, global_context, x_chunk_length, y_chunk_length, main_model, sub_model, max_thread_num)

    # Draft mode needs an x->y mapping, so perform the initial split first.
    if writer_mode == 'draft':
        target_chunk = novel_writer.get_chunk(pair_span=chunk_span)
        new_target_chunk = novel_writer.map_text_wo_llm(target_chunk)
        novel_writer.apply_chunks([target_chunk], [new_target_chunk])
        chunk_span = novel_writer.get_chunk_pair_span(new_target_chunk)

    # Snapshot of the pre-write state, used for the final display diff.
    init_novel_writer = load_novel_writer(writer_mode, list(novel_writer.xy_pairs), global_context, x_chunk_length, y_chunk_length, main_model, sub_model, max_thread_num)

    # TODO: writer.write should handle both empty and non-empty y for any
    # prompt; i.e. "新建正文" could equally be achieved through "扩写正文".
    generator = novel_writer.write(prompt_content, pair_span=chunk_span)

    prompt_outputs = []
    last_yield_time = time.time()  # throttle timer for streaming updates

    prompt_name = ''
    for kp_msg in generator:
        if isinstance(kp_msg, KeyPointMsg):
            # Supporting key-point checkpoints would require computing an edit
            # diff here and yielding the writer state as well.
            prompt_name = kp_msg.prompt_name
            continue
        else:
            chunk_list = kp_msg

        current_cost = 0
        currency_symbol = ''
        current_model = ''
        data_chunks = []
        prompt_outputs.clear()
        for e in chunk_list:
            if e is None: continue  # None: this chunk has not been processed yet
            output, chunk = e
            if output is None: continue  # output None: chunk returned without yielding, so no LLM call was made
            prompt_outputs.append(output)
            current_text = ""
            current_model = output['response_msgs'].model
            current_cost += output['response_msgs'].cost
            currency_symbol = output['response_msgs'].currency_symbol
            if 'plot2text' in output:
                current_text += f"正在建立映射关系..." + '\n'
            else:
                current_text = output['text']
            data_chunks.append((chunk.x_chunk, chunk.y_chunk, current_text))

        if only_prompt:
            yield {'prompts': [e['response_msgs'] for e in prompt_outputs]}
            return

        current_time = time.time()
        if current_time - last_yield_time >= 0.2:  # throttle to one update per 0.2s
            # BUGFIX: the original `msg=A + B if current_model else ''` bound
            # the conditional to the WHOLE concatenation, so the entire
            # progress message was dropped whenever current_model was empty.
            # Only the model/cost suffix is meant to be conditional.
            progress_msg = f"正在 {prompt_name} ({len(prompt_outputs)} / {len(chunk_list)})"
            if current_model:
                progress_msg += f" 模型:{current_model} 花费:{current_cost:.5f}{currency_symbol}"
            yield delta_wrapper(data_chunks, done=False, msg=progress_msg)
            last_yield_time = current_time  # Update the last yield time

    # Compute an edit-level diff for frontend display; diff support will move
    # out of the writer eventually since it is purely a presentation concern.
    data_chunks = init_novel_writer.diff_to(novel_writer, pair_span=chunk_span)

    yield delta_wrapper(data_chunks, done=True, msg='创作完成!')
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
@app.route('/write', methods=['POST'])
def write():
    """SSE endpoint: run a writing pass and stream chunk updates to the client.

    Expects a JSON body with writer_mode, chunk_list, chunk_span,
    prompt_content, chunk lengths, model names, global_context, only_prompt
    and an optional 'settings' dict.
    """
    data = request.json
    writer_mode = data['writer_mode']
    chunk_list = data['chunk_list']
    chunk_span = data['chunk_span']
    prompt_content = data['prompt_content']
    x_chunk_length = data['x_chunk_length']
    y_chunk_length = data['y_chunk_length']
    main_model = data['main_model']
    sub_model = data['sub_model']
    global_context = data['global_context']
    only_prompt = data['only_prompt']

    # BUGFIX: fall back to the configured default. Previously max_thread_num
    # was only bound when 'settings' was present in the request, so requests
    # without it raised NameError inside the generator below.
    max_thread_num = MAX_THREAD_NUM
    if 'settings' in data:
        max_thread_num = data['settings']['MAX_THREAD_NUM']

    # Generate unique stream ID
    stream_id = str(time.time())
    active_streams[stream_id] = True

    def generate():
        try:
            # Send stream ID to client
            yield f"data: {json.dumps({'stream_id': stream_id})}\n\n"

            for result in call_write(writer_mode, list(chunk_list), global_context, chunk_span, prompt_content, x_chunk_length, y_chunk_length, main_model, sub_model, max_thread_num, only_prompt):
                if not active_streams.get(stream_id, False):
                    # Stream was stopped by client
                    print(f"Stream was stopped by client: {stream_id}")
                    return

                yield f"data: {json.dumps(result)}\n\n"
        except Exception as e:
            # Surface the error inside the affected chunks so the UI shows it.
            error_msg = f"创作出错:\n{str(e)}"
            error_chunk_list = [[*e[:2], error_msg] for e in chunk_list[chunk_span[0]:chunk_span[1]]]

            error_data = {
                "done": True,
                "chunk_type": "init",
                "chunk_list": error_chunk_list
            }
            yield f"data: {json.dumps(error_data)}\n\n"
        finally:
            # Clean up stream tracking
            if stream_id in active_streams:
                del active_streams[stream_id]

    return Response(generate(), mimetype='text/event-stream')
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
@app.route('/summary', methods=['POST'])
def process_novel_text():
    """SSE endpoint: summarize an imported novel and stream progress updates."""
    data = request.json
    content = data['content']
    novel_name = data['novel_name']

    # Generate unique stream ID
    stream_id = str(time.time())
    active_streams[stream_id] = True

    def generate():
        try:
            yield f"data: {json.dumps({'stream_id': stream_id})}\n\n"

            main_model = get_model_config_from_provider_model(data['main_model'])
            sub_model = get_model_config_from_provider_model(data['sub_model'])
            max_novel_summary_length = data['settings']['MAX_NOVEL_SUMMARY_LENGTH']
            max_thread_num = data['settings']['MAX_THREAD_NUM']
            last_yield_time = 0
            # BUGFIX: initialize so the post-loop flush cannot hit a NameError
            # when process_novel yields no results at all.
            yield_value = None
            for result in process_novel(content, novel_name, main_model, sub_model, max_novel_summary_length, max_thread_num):
                if not active_streams.get(stream_id, False):
                    # Stream was stopped by client
                    print(f"Stream was stopped by client: {stream_id}")
                    return

                current_time = time.time()
                yield_value = f"data: {json.dumps(result)}\n\n"
                # Throttle to at most one SSE event per 0.2 seconds.
                if current_time - last_yield_time >= 0.2:
                    last_yield_time = current_time
                    yield yield_value
            # Flush the most recent result (mirrors the original behavior,
            # which may re-send the last event).
            if yield_value is not None and current_time - last_yield_time < 0.2:
                # Save last yield to yaml file (debug aid left from the
                # original implementation; written next to the CWD).
                import yaml
                result_dict = json.loads(yield_value.replace('data: ', '').strip())
                with open('tmp.yaml', 'w', encoding='utf-8') as f:
                    yaml.dump(result_dict, f, allow_unicode=True)

                yield yield_value  # Ensure last yield is returned

        except Exception as e:
            error_data = {
                "progress_msg": f"处理出错:{str(e)}",
            }
            yield f"data: {json.dumps(error_data)}\n\n"
        finally:
            # Clean up stream tracking
            if stream_id in active_streams:
                del active_streams[stream_id]

    return Response(generate(), mimetype='text/event-stream')
|
| 328 |
+
|
| 329 |
+
# Registry of in-flight SSE streams: stream_id -> True while active,
# flipped to False once the client requests cancellation.
active_streams = {}

@app.route('/stop_stream', methods=['POST'])
def stop_stream():
    """Mark the given SSE stream as cancelled so its generator stops early."""
    stream_id = request.json.get('stream_id')
    if stream_id in active_streams:
        active_streams[stream_id] = False
    return jsonify({'success': True})

if __name__ == '__main__':
    app.run(host=BACKEND_HOST, port=BACKEND_PORT, debug=False)
|
backend_utils.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from llm_api import ModelConfig
|
| 2 |
+
|
| 3 |
+
def get_model_config_from_provider_model(provider_model):
    """Build a ModelConfig from a 'provider/model' identifier string."""
    from config import API_SETTINGS  # local import; config also imports settings lazily

    provider, model = provider_model.split('/', 1)
    config_kwargs = {**API_SETTINGS[provider], 'model': model}

    if provider == 'doubao':
        # Doubao requires one endpoint id per model, matched by list position.
        idx = config_kwargs['available_models'].index(model)
        endpoint_ids = config_kwargs['endpoint_ids']
        config_kwargs['endpoint_id'] = endpoint_ids[idx] if idx < len(endpoint_ids) else ''

    # ModelConfig does not accept the list-valued bookkeeping fields.
    config_kwargs.pop('available_models', None)
    config_kwargs.pop('endpoint_ids', None)

    return ModelConfig(**config_kwargs)
|
config.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
from dotenv import dotenv_values, load_dotenv

# Load the project's .env into os.environ before reading any settings below.
print("Loading .env file...")
env_path = os.path.join(os.path.dirname(__file__), '.env')
if os.path.exists(env_path):
    env_dict = dotenv_values(env_path)

    # NOTE(review): this echoes every value — including API keys/secrets —
    # to stdout/logs; consider masking sensitive values.
    print("Environment variables to be loaded:")
    for key, value in env_dict.items():
        print(f"{key}={value}")
    print("-" * 50)

    # BUGFIX: dotenv_values maps valueless keys (a bare "FOO" line) to None,
    # and os.environ only accepts strings — filter those out instead of
    # crashing with TypeError.
    os.environ.update({k: v for k, v in env_dict.items() if v is not None})
    print(f"Loaded environment variables from: {env_path}")
else:
    print("Warning: .env file not found")


# Thread Configuration
MAX_THREAD_NUM = int(os.getenv('MAX_THREAD_NUM', 5))

# Maximum text length processed when importing a novel for summarization.
MAX_NOVEL_SUMMARY_LENGTH = int(os.getenv('MAX_NOVEL_SUMMARY_LENGTH', 20000))

# MongoDB Configuration
# (the misspelled MONOGODB names are kept as-is: other modules import them)
ENABLE_MONOGODB = os.getenv('ENABLE_MONGODB', 'false').lower() == 'true'
MONGODB_URI = os.getenv('MONGODB_URI', 'mongodb://127.0.0.1:27017/')
MONOGODB_DB_NAME = os.getenv('MONGODB_DB_NAME', 'llm_api')
ENABLE_MONOGODB_CACHE = os.getenv('ENABLE_MONGODB_CACHE', 'true').lower() == 'true'
CACHE_REPLAY_SPEED = float(os.getenv('CACHE_REPLAY_SPEED', 2))
CACHE_REPLAY_MAX_DELAY = float(os.getenv('CACHE_REPLAY_MAX_DELAY', 5))

# API Cost Limits
API_COST_LIMITS = {
    'HOURLY_LIMIT_RMB': float(os.getenv('API_HOURLY_LIMIT_RMB', 100)),
    'DAILY_LIMIT_RMB': float(os.getenv('API_DAILY_LIMIT_RMB', 500)),
    'USD_TO_RMB_RATE': float(os.getenv('API_USD_TO_RMB_RATE', 7))
}

# API Settings — one entry per provider; list-valued fields are stripped
# before a ModelConfig is built from an entry (see backend_utils).
API_SETTINGS = {
    'wenxin': {
        'ak': os.getenv('WENXIN_AK', ''),
        'sk': os.getenv('WENXIN_SK', ''),
        'available_models': os.getenv('WENXIN_AVAILABLE_MODELS', '').split(','),
        'max_tokens': 4096,
    },
    'doubao': {
        'api_key': os.getenv('DOUBAO_API_KEY', ''),
        'endpoint_ids': os.getenv('DOUBAO_ENDPOINT_IDS', '').split(','),
        'available_models': os.getenv('DOUBAO_AVAILABLE_MODELS', '').split(','),
        'max_tokens': 4096,
    },
    'gpt': {
        'base_url': os.getenv('GPT_BASE_URL', ''),
        'api_key': os.getenv('GPT_API_KEY', ''),
        'proxies': os.getenv('GPT_PROXIES', ''),
        'available_models': os.getenv('GPT_AVAILABLE_MODELS', '').split(','),
        'max_tokens': 4096,
    },
    'zhipuai': {
        'api_key': os.getenv('ZHIPUAI_API_KEY', ''),
        'available_models': os.getenv('ZHIPUAI_AVAILABLE_MODELS', '').split(','),
        'max_tokens': 4096,
    },
    'local': {
        'base_url': os.getenv('LOCAL_BASE_URL', ''),
        'api_key': os.getenv('LOCAL_API_KEY', ''),
        'available_models': os.getenv('LOCAL_AVAILABLE_MODELS', '').split(','),
        'max_tokens': 4096,
    }
}

for model in API_SETTINGS.values():
    model['available_models'] = [e.strip() for e in model['available_models']]

DEFAULT_MAIN_MODEL = os.getenv('DEFAULT_MAIN_MODEL', 'wenxin/ERNIE-Novel-8K')
DEFAULT_SUB_MODEL = os.getenv('DEFAULT_SUB_MODEL', 'wenxin/ERNIE-3.5-8K')

ENABLE_ONLINE_DEMO = os.getenv('ENABLE_ONLINE_DEMO', 'false').lower() == 'true'
|
core/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# core 模块为LongNovelGPT到2.0版本之间的过渡,在core模块中进行一些新功能和设计的尝试
|
core/backend.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import importlib
|
| 3 |
+
from core.draft_writer import DraftWriter
|
| 4 |
+
from core.plot_writer import PlotWriter
|
| 5 |
+
from core.outline_writer import OutlineWriter
|
| 6 |
+
from core.writer_utils import KeyPointMsg
|
| 7 |
+
from core.diff_utils import match_span_by_char
|
| 8 |
+
import copy
|
| 9 |
+
import types
|
| 10 |
+
|
| 11 |
+
def load_novel_writer(writer, setting) -> DraftWriter:
    """Rebuild the active writer object from the serialized writer state dict."""
    current_w_name = writer['current_w']
    current_w = writer[current_w_name]

    xy_pairs = list(current_w.get('xy_pairs', [['', '']]))
    x_len = current_w['x_chunk_length']
    y_len = current_w['y_chunk_length']

    # Dispatch table replacing the original match/case; chapters_w and plot_w
    # both map onto PlotWriter.
    writer_classes = {
        'draft_w': DraftWriter,
        'outline_w': OutlineWriter,
        'chapters_w': PlotWriter,
        'plot_w': PlotWriter,
    }
    if current_w_name not in writer_classes:
        raise ValueError(f"unknown writer: {current_w_name}")

    return writer_classes[current_w_name](
        xy_pairs=xy_pairs,
        model=setting['model'],
        sub_model=setting['sub_model'],
        x_chunk_length=x_len,
        y_chunk_length=y_len,
    )
|
| 35 |
+
|
| 36 |
+
def dump_novel_writer(writer, novel_writer, apply_chunks=None, cost=0, currency_symbol='¥'):
    """Serialize novel_writer's state back into a copy of the writer dict.

    Args:
        writer: serialized writer state; not mutated (a deep copy is returned).
        novel_writer: live writer whose xy_pairs are persisted.
        apply_chunks: pending chunk edits to attach (defaults to an empty dict).
        cost: cost of the last operation.
        currency_symbol: currency symbol associated with cost.

    Returns:
        A new writer dict with the current sub-writer's state updated.
    """
    # BUGFIX: the original used a mutable default (apply_chunks={}), which is
    # shared across calls; use the None sentinel instead.
    if apply_chunks is None:
        apply_chunks = {}

    # TODO: by design dump should not mutate the original writer, but this
    # deepcopy may be expensive.
    new_writer = copy.deepcopy(writer)

    current_w_name = new_writer['current_w']
    current_w = new_writer[current_w_name]

    # if current_w_name == 'draft_w':
    #     assert isinstance(novel_writer, DraftWriter), "draft_w需要传入DraftWriter"

    current_w['xy_pairs'] = list(novel_writer.xy_pairs)

    current_w['current_cost'] = cost
    current_w['currency_symbol'] = currency_symbol
    #current_w['total_cost'] += current_w['current_cost']

    current_w['apply_chunks'] = apply_chunks

    return new_writer
|
| 54 |
+
|
| 55 |
+
def call_write_long_novel(writer, setting):
    """Drive the full outline -> plot -> draft pipeline as a generator.

    Yields intermediate writer states (for front-end progress display) and
    returns the final writer state. Each pipeline step is described as an
    "op" dict whose 'eval' string is evaluated in this function's scope;
    sub-generators are delegated to via ``yield from``.
    """
    writer = copy.deepcopy(writer)  # never mutate the caller's state
    progress = writer['progress']

    # NOTE(review): the `or True` forces the op list to be rebuilt on every
    # call (only cur_op_i survives from a previous run) — presumably so that
    # prompt/op changes take effect on resume; confirm this is intentional.
    if not progress or True:
        progress = dict(
            cur_op_i = progress['cur_op_i'] if progress else 0,
            # SECURITY NOTE: these 'eval'/'before_eval' strings are executed
            # with exec()/eval() below. They are hard-coded constants here,
            # never user input — keep it that way.
            ops = [
                {
                    'before_eval': 'writer["current_w"] = "outline_w"',
                    'eval': 'call_write(writer, setting, False, "构思全书的大致剧情,并将其以一个故事的形式写下来,只写大致情节。")',
                    'title': '创作大纲',
                    'subtitle': '生成大纲'
                },
                {
                    'eval': 'call_accept(writer, setting)',
                },
                {
                    'eval': 'call_write(writer, setting, True, "对整个情节进行重写,使其更加有故事性。")',
                    'title': '创作大纲',
                    'subtitle': '润色大纲',
                },
                {
                    'eval': 'call_accept(writer, setting)',
                },
                # Plot-writing stage below
                {
                    # NOTE(review): init_chapters_w must be resolvable in this
                    # module's namespace at eval time — confirm the import.
                    'before_eval': 'init_chapters_w(writer)',
                    'eval': 'call_write(writer, setting, False, "丰富其中的剧情细节。")',
                    'title': '创作剧情',
                    'subtitle': '生成剧情'
                },
                {
                    'eval': 'call_accept(writer, setting)',
                },
                {
                    'eval': 'call_write(writer, setting, True, "对情节进行重写,使其有更多的剧情细节,同时更加有具有故事性。")',
                    'title': '创作剧情',
                    'subtitle': '扩充剧情',
                },
                {
                    'eval': 'call_accept(writer, setting)',
                },
                # Draft (prose) stage below
                {
                    'before_eval': 'init_draft_w(writer)',
                    'eval': 'call_write(writer, setting, False, "创作的是正文,而不是剧情,需要像一个小说家那样去描写这个故事。")',
                    'title': '创作正文',
                    'subtitle': '生成正文'
                },
                {
                    'eval': 'call_accept(writer, setting)',
                },
                {
                    'eval': 'call_write(writer, setting, True, "润色正文")',
                    'title': '创作正文',
                    'subtitle': '润色正文'
                },
                {
                    'eval': 'call_accept(writer, setting)',
                }
            ]
        )

    # TODO: consider providing context already at init_plot time, similar to rewrite_plot

    # Forward-fill title/subtitle so untitled ops (e.g. the accept steps)
    # inherit the label of the preceding titled op.
    title, subtitle = '', ''
    for op in progress['ops']:
        if 'title' not in op:
            op['title'], op['subtitle'] = title, subtitle
        else:
            title, subtitle = op['title'], op['subtitle']


    writer['progress'] = progress
    yield writer

    while progress['cur_op_i'] < len(progress['ops']):
        current_op = progress['ops'][progress['cur_op_i']]
        if 'before_eval' in current_op:
            # exec mutates the local `writer` dict in place (item assignment /
            # helper call) — it does not rebind the local name.
            exec(current_op['before_eval'])
        # The evaluated call returns a generator; delegate to it and capture
        # its return value as the updated writer state.
        writer = yield from eval(current_op['eval'])
        progress = writer['progress']

        progress['cur_op_i'] += 1
        yield writer  # a bump of cur_op_i marks this yield as a "stable" writer_state snapshot

    return writer
|
| 143 |
+
|
| 144 |
+
def match_quote_text(writer, setting, quote_text):
    """Locate ``quote_text`` inside the active writer's y-text.

    Returns:
        (aligned_span, matched_text) when more than half of the quote's
        characters match; otherwise (None, '').
    """
    nw = load_novel_writer(writer, setting)
    full_text = nw.y
    raw_span, ratio = match_span_by_char(full_text, quote_text)
    # Guard clause: reject weak matches early.
    if ratio <= 0.5:
        return None, ''
    aligned, _ = nw.align_span(y_span=raw_span)
    start, end = aligned
    return aligned, full_text[start:end]
|
| 153 |
+
|
| 154 |
+
# 这是后端函数,接受前端writer_state的copy做为输入
|
| 155 |
+
# 返回的是修改后的writer_state,注意yield的值一般被用于前端展示执行的过程和进度
|
| 156 |
+
# 只有return值才会被前端考虑用于writer_state的更新
|
| 157 |
+
# This is a backend function; it receives a copy of the front-end writer_state
# as input. It returns the modified writer_state; values that are *yielded*
# are generally used by the front end to display progress, and only the
# *returned* value is considered for updating writer_state.
def call_write(writer, setting, auto_write=False, suggestion=None):
    """Run one write pass of the active writer, yielding progress snapshots.

    Args:
        writer: front-end writer-state dict (a copy; mutated and returned).
        setting: settings dict with model configuration.
        auto_write: use the writer's auto_write flow (no quote allowed).
        suggestion: user instruction forwarded to the writer's write flow.

    Yields:
        KeyPointMsg pass-throughs and intermediate writer snapshots.

    Returns:
        The writer dict with 'apply_chunks' (pending diff) and
        'prompt_outputs' filled in.
    """
    novel_writer = load_novel_writer(writer, setting)

    current_w = writer[writer['current_w']]
    current_w['xy_pairs'] = list(novel_writer.xy_pairs)

    quote_span = writer['quote_span']

    if auto_write:
        assert quote_span is None, "auto_write模式下,不能有quote_text"
        generator = novel_writer.auto_write()
    else:
        # TODO: writer.write should guarantee that any prompt works both when
        # y is empty and when y already has content. In other words, even if a
        # separate "create new text" action exists, "expand text" should be
        # able to achieve the same effect.
        generator = novel_writer.write(suggestion, y_span=quote_span)

    prompt_outputs = []
    for kp_msg in generator:
        if isinstance(kp_msg, KeyPointMsg):
            # To support key-point saving we would need to compute an edit
            # delta and yield the writer here.
            yield kp_msg
            continue
        else:
            chunk_list = kp_msg

        current_cost = 0
        # Fix: default the currency symbol up front. Previously it was only
        # assigned inside the loop below, so an empty chunk_list raised
        # NameError at the dump_novel_writer call.
        currency_symbol = '¥'
        apply_chunks = []
        prompt_outputs.clear()
        for output, chunk in chunk_list:
            prompt_outputs.append(output)
            current_text = ""
            current_cost += output['response_msgs'].cost
            currency_symbol = output['response_msgs'].currency_symbol
            cost_info = f"\n(预计花费:{output['response_msgs'].cost:.4f}{output['response_msgs'].currency_symbol})"
            if 'plot2text' in output:
                current_text += f"正在建立映射关系..." + cost_info + '\n'
            else:
                current_text += output['text'] + cost_info + '\n'
            apply_chunks.append((chunk, 'y_chunk', current_text))

        new_writer = dump_novel_writer(writer, novel_writer, apply_chunks=apply_chunks, cost=current_cost, currency_symbol=currency_symbol)
        # NOTE(review): prompt_outputs is the same list object that is cleared
        # on the next iteration; the front end deep-copies when it needs to
        # keep a snapshot — confirm no other consumer holds this reference.
        new_writer['prompt_outputs'] = prompt_outputs
        yield new_writer

    # Compute an edit-level diff for front-end display; diff functionality
    # will no longer be provided by the writer itself, since this is purely a
    # presentation concern.
    apply_chunks = []
    for chunk, key, value in load_novel_writer(writer, setting).diff_to(novel_writer):
        apply_chunks.append((chunk, key, value))
    writer[writer['current_w']]['apply_chunks'] = apply_chunks
    writer['prompt_outputs'] = prompt_outputs
    return writer
|
| 208 |
+
|
| 209 |
+
def call_accept(writer, setting):
    """Apply every pending edit chunk to the active writer and persist it.

    Returns a new writer-state dict with the applied text committed to
    xy_pairs (via dump_novel_writer's deep copy).
    """
    active = writer[writer['current_w']]

    nw = load_novel_writer(writer, setting)
    for chunk, key, text in active['apply_chunks']:
        nw.apply_chunk(chunk, key, text)

    return dump_novel_writer(writer, nw)
|
core/diff_utils.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import difflib
|
| 2 |
+
from difflib import SequenceMatcher
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def match_span_by_char(text, chunk):
    """Find where ``chunk`` approximately occurs in ``text``, character-wise.

    Uses difflib to collect all 'equal' opcode regions between the two
    strings.

    Returns:
        (span, ratio): ``span`` is the (start, end) range in ``text`` from the
        first to the last matching region; ``ratio`` is the fraction of
        ``chunk``'s characters that matched. Returns (None, 0) when nothing
        matches at all.
    """
    sm = difflib.SequenceMatcher(None, text, chunk)

    # Keep only the regions of `text` that match `chunk` exactly.
    equal_spans = [
        (i1, i2)
        for tag, i1, i2, _j1, _j2 in sm.get_opcodes()
        if tag == 'equal'
    ]

    if not equal_spans:
        return None, 0

    overall_span = (equal_spans[0][0], equal_spans[-1][1])
    matched_chars = sum(end - start for start, end in equal_spans)
    return overall_span, matched_chars / len(chunk)
|
| 24 |
+
|
| 25 |
+
def match_sequences(a_list, b_list):
    """
    Match two lists of strings against each other, returning paired index
    ranges.

    Args:
        a_list: first list of strings
        b_list: second list of strings

    Returns:
        list[((l, r), (j, k))]: list of matched index pairs, where (l, r) is a
        start/end index range into a_list and (j, k) a start/end range into
        b_list
    """
    # NOTE(review): m and n deliberately exclude the last element of each
    # list; the final catch-all pair appended at the bottom maps the
    # remainders onto each other — confirm this off-by-one is intentional.
    m, n = len(a_list) - 1, len(b_list) - 1
    matches = []
    i = j = 0

    while i < m and j < n:
        # Current best candidate pairing that starts at (i, j).
        best_match = None
        best_ratio = -1  # sentinel so any real ratio wins (not a threshold)

        # Try merging up to 3 consecutive elements from each side.
        for l in range(i, min(i + 3, m)):  # limit look-ahead on a_list
            current_a = ''.join(a_list[i:l + 1])

            for r in range(j, min(j + 3, n)):  # limit look-ahead on b_list
                current_b = ''.join(b_list[j:r + 1])

                # Score symmetric char-level overlap using the existing
                # match_span_by_char helper; spans themselves are unused.
                span1, ratio1 = match_span_by_char(current_b, current_a)
                span2, ratio2 = match_span_by_char(current_a, current_b)
                ratio = ratio1 * ratio2

                if ratio > best_ratio:
                    best_ratio = ratio
                    best_match = ((i, l + 1), (j, r + 1))

        if best_match:
            matches.append(best_match)
            i = best_match[0][1]
            j = best_match[1][1]
        else:
            # No acceptable match found: advance both cursors one step.
            i += 1
            j += 1

    # Catch-all: pair whatever remains (including the excluded last elements).
    matches.append(((i, m+1), (j, n+1)))

    return matches
|
| 74 |
+
|
| 75 |
+
def get_chunk_changes(source_chunk_list, target_chunk_list):
    """Compute chunk-level change ranges between two lists of text chunks.

    The chunks on each side are joined with a separator, diffed at character
    level with SequenceMatcher, and each chunk is then classified as either
    'delete_or_insert' (mostly removed/added text) or 'replace_or_equal'
    (mostly retained/modified text). Finally a two-pointer sweep groups the
    classified chunks into aligned change ranges.

    Returns:
        list of (start_i, end_i, start_j, end_j) tuples: source chunks
        [start_i:end_i] correspond to target chunks [start_j:end_j].
    """
    # Separator chosen to be unlikely to occur inside chunk text.
    SEPARATOR = "%|%"
    source_text = SEPARATOR.join(source_chunk_list)
    target_text = SEPARATOR.join(target_chunk_list)

    # Initialize per-chunk tag counters (character counts per opcode class).
    source_chunk_stats = [{'delete_or_insert': 0, 'replace_or_equal': 0} for _ in source_chunk_list]
    target_chunk_stats = [{'delete_or_insert': 0, 'replace_or_equal': 0} for _ in target_chunk_list]

    # Character offsets where each chunk starts in the joined text
    # (plus the total length as the final sentinel).
    source_positions = [0]
    target_positions = [0]
    pos = 0
    for chunk in source_chunk_list[:-1]:
        pos += len(chunk) + len(SEPARATOR)
        source_positions.append(pos)
    source_positions.append(len(source_text))

    pos = 0
    for chunk in target_chunk_list[:-1]:
        pos += len(chunk) + len(SEPARATOR)
        target_positions.append(pos)
    target_positions.append(len(target_text))

    # Credit the overlap of [start, end) with each chunk's range to `tag`.
    def update_chunk_stats(positions, stats, start, end, tag):
        for i in range(len(positions) - 1):
            chunk_start = positions[i]
            chunk_end = positions[i + 1]

            overlap_start = max(chunk_start, start)
            overlap_end = min(chunk_end, end)

            if overlap_end > overlap_start:
                stats[i][tag] += overlap_end - overlap_start

    matcher = SequenceMatcher(None, source_text, target_text)

    # Walk every opcode block and accumulate per-chunk statistics.
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == 'replace' or tag == 'equal':
            update_chunk_stats(source_positions, source_chunk_stats, i1, i2, 'replace_or_equal')
            update_chunk_stats(target_positions, target_chunk_stats, j1, j2, 'replace_or_equal')
        elif tag == 'delete':
            update_chunk_stats(source_positions, source_chunk_stats, i1, i2, 'delete_or_insert')
        elif tag == 'insert':
            update_chunk_stats(target_positions, target_chunk_stats, j1, j2, 'delete_or_insert')

    # A chunk's final tag is whichever class covered more of its characters;
    # ties go to 'replace_or_equal'.
    def get_final_tag(stats):
        return 'delete_or_insert' if stats['delete_or_insert'] > stats['replace_or_equal'] else 'replace_or_equal'

    source_chunk_tags = [get_final_tag(stats) for stats in source_chunk_stats]
    target_chunk_tags = [get_final_tag(stats) for stats in target_chunk_stats]

    # Two-pointer sweep to group tagged chunks into aligned change ranges.
    changes = []
    i = j = 0  # i indexes source_chunk_list, j indexes target_chunk_list
    start_i = start_j = 0
    m, n = len(source_chunk_list), len(target_chunk_list)
    while i < m or j < n:
        if i < m and source_chunk_tags[i] == 'delete_or_insert':
            while i < m and source_chunk_tags[i] == 'delete_or_insert': i += 1
        elif j < n and target_chunk_tags[j] == 'delete_or_insert':
            while j < n and target_chunk_tags[j] == 'delete_or_insert': j += 1
        elif i < m and j < n and source_chunk_tags[i] == 'replace_or_equal' and target_chunk_tags[j] == 'replace_or_equal':
            while i < m and j < n and source_chunk_tags[i] == 'replace_or_equal' and target_chunk_tags[j] == 'replace_or_equal':
                i += 1
                j += 1
        else:
            # TODO: this algorithm still has issues, namely the pairing of
            # 'equal' runs (acknowledged defect; left as-is).
            break

        # Whenever either pointer advanced, record the accumulated range.
        if (i > start_i or j > start_j):
            changes.append((start_i, i, start_j, j))
            start_i, start_j = i, j

    # Flush whatever remains as a final catch-all change.
    if (i < m or j < n):
        changes.append((start_i, m, start_j, n))

    return changes
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
# 使用示例
|
| 159 |
+
# Usage example
def test_get_chunk_changes():
    """Smoke-test get_chunk_changes on a sample chapter diff and print the mapping."""
    source_chunks = ['', '', '', '第3章 初露锋芒\n在高人指导下,萧炎的斗气水平迅速提升,开始在家族中引起注意。\n', '', '第4章 异火初现\n萧炎得知“异火”的存在,决定踏上寻找异火的旅程。\n']
    target_chunks = ['', '第3章 初露锋芒\n在高人指导下,萧炎的斗气水平迅速提升,开始在家族中引起注意。', '第3.5章 家族试炼\n萧炎参加家族举办的试炼,凭借新学的斗技和炼丹术,展现出超凡实力,获得家族长老的关注和认可。', '第4章 异火初现\n萧炎得知“异火”的存在,决定踏上寻找异火的旅程。']

    changes = get_chunk_changes(source_chunks, target_chunks)

    # First pass: print the index mapping of each change.
    for src_lo, src_hi, tgt_lo, tgt_hi in changes:
        print(f"Source chunks {src_lo}:{src_hi} -> Target chunks {tgt_lo}:{tgt_hi}")

    # Second pass: print the actual text transformation per change.
    for src_lo, src_hi, tgt_lo, tgt_hi in changes:
        print('-' * 20)
        print(f"{''.join(source_chunks[src_lo:src_hi])} -> {''.join(target_chunks[tgt_lo:tgt_hi])}")

if __name__ == "__main__":
    test_get_chunk_changes()
|
core/draft_writer.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from core.writer_utils import KeyPointMsg
|
| 2 |
+
from core.writer import Writer
|
| 3 |
+
|
| 4 |
+
from prompts.创作正文.prompt import main as prompt_draft
|
| 5 |
+
from prompts.提炼.prompt import main as prompt_summary
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DraftWriter(Writer):
    """Writer specialization that turns plot chunks (x) into prose drafts (y).

    Delegates chunking and batched LLM application to the Writer base class;
    this subclass only wires in the draft/summary prompts and input checks.
    """

    def __init__(self, xy_pairs, global_context, model=None, sub_model=None, x_chunk_length=500, y_chunk_length=1000, max_thread_num=5):
        # x_chunk_length / y_chunk_length are the plot/prose window sizes;
        # max_thread_num bounds parallel LLM calls (handled by the base class).
        super().__init__(xy_pairs, global_context, model, sub_model, x_chunk_length=x_chunk_length, y_chunk_length=y_chunk_length, max_thread_num=max_thread_num)

    def write(self, user_prompt, pair_span=None):
        """Generate prose for the plot chunks within ``pair_span``.

        Raises:
            Exception: when no plot text is available, or it is 5 characters
                or fewer.
        """
        target_chunk = self.get_chunk(pair_span=pair_span)
        if not target_chunk.x_chunk:
            raise Exception("需要提供剧情。")
        if len(target_chunk.x_chunk) <= 5:
            raise Exception("剧情不能少于5个字。")

        chunks = self.get_chunks(pair_span)

        # Generator: streams progress while the draft prompt is applied chunk
        # by chunk.
        yield from self.batch_write_apply_text(chunks, prompt_draft, user_prompt)

    def summary(self, pair_span=None):
        """Summarize the prose within ``pair_span`` back into plot text.

        Raises:
            Exception: when there is no prose, or it is 5 characters or fewer.
        """
        target_chunk = self.get_chunk(pair_span=pair_span)
        if not target_chunk.y_chunk:
            raise Exception("没有正文需要总结。")
        if len(target_chunk.y_chunk) <= 5:
            raise Exception("需要总结的正文不能少于5个字。")

        # Split into smaller pieces first so that get_chunks works properly.
        new_target_chunk = self.map_text_wo_llm(target_chunk)
        self.apply_chunks([target_chunk], [new_target_chunk])
        chunk_span = self.get_chunk_pair_span(new_target_chunk)

        # context_length_ratio=0: summarization needs no surrounding context.
        chunks = self.get_chunks(chunk_span, context_length_ratio=0)

        yield from self.batch_write_apply_text(chunks, prompt_summary, "提炼剧情")

    def split_into_chapters(self):
        # Placeholder: chapter splitting is not implemented yet.
        pass

    def get_model(self):
        # Accessor kept for interface compatibility with callers.
        return self.model

    def get_sub_model(self):
        # Accessor kept for interface compatibility with callers.
        return self.sub_model
|
core/frontend.py
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from rich.traceback import install
|
| 3 |
+
install(show_locals=False)
|
| 4 |
+
|
| 5 |
+
import gradio as gr
|
| 6 |
+
import yaml
|
| 7 |
+
import functools
|
| 8 |
+
import time
|
| 9 |
+
import sys
|
| 10 |
+
import os
|
| 11 |
+
import copy
|
| 12 |
+
|
| 13 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 14 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 15 |
+
|
| 16 |
+
from config import RENDER_SAVE_LOAD_BTN, RENDER_STOP_BTN
|
| 17 |
+
from core.backend import call_write, call_accept, match_quote_text
|
| 18 |
+
from core.frontend_copy import enable_copy_js, on_copy
|
| 19 |
+
from core.frontend_setting import new_setting, render_setting
|
| 20 |
+
from core.frontend_utils import (
|
| 21 |
+
title, info,
|
| 22 |
+
create_progress_md, create_text_md, messages2chatbot,
|
| 23 |
+
init_writer, has_accept, is_running, try_cancel, writer_y_is_empty, writer_x_is_empty,
|
| 24 |
+
cancellable, process_writer_to_backend, process_writer_from_backend,
|
| 25 |
+
init_chapters_w, init_draft_w
|
| 26 |
+
)
|
| 27 |
+
from core.writer_utils import KeyPointMsg
|
| 28 |
+
|
| 29 |
+
from prompts.baseprompt import clean_txt_content, load_prompt
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# 读取YAML文件
|
| 33 |
+
with open('prompts/idea-examples.yaml', 'r', encoding='utf-8') as file:
|
| 34 |
+
examples_data = yaml.safe_load(file)
|
| 35 |
+
|
| 36 |
+
# 准备示例列表
|
| 37 |
+
examples = [[example['idea']] for example in examples_data['examples']]
|
| 38 |
+
|
| 39 |
+
with gr.Blocks(head=enable_copy_js) as demo:
|
| 40 |
+
gr.HTML(title)
|
| 41 |
+
with gr.Accordion("使用指南"):
|
| 42 |
+
gr.Markdown(info)
|
| 43 |
+
|
| 44 |
+
writer_state = gr.State(init_writer('', check_empty=False))
|
| 45 |
+
setting_state = gr.State(new_setting())
|
| 46 |
+
|
| 47 |
+
if RENDER_SAVE_LOAD_BTN:
|
| 48 |
+
with gr.Row():
|
| 49 |
+
save_button = gr.Button("保存状态")
|
| 50 |
+
load_button = gr.Button("加载状态")
|
| 51 |
+
save_file_name = gr.Textbox(value='states', placeholder='输入文件名', lines=1, label=None, show_label=False, container=False)
|
| 52 |
+
|
| 53 |
+
def save_states(save_file_name, writer, setting):
    """Persist the current writer/setting states to <save_file_name>.json."""
    import json
    target = save_file_name + '.json'
    payload = {
        'writer': writer,
        'setting': setting
    }
    with open(target, 'w', encoding='utf-8') as fp:
        json.dump(payload, fp, ensure_ascii=False, indent=2)
    gr.Info(f"状态已保存到文件:{target}")
|
| 62 |
+
|
| 63 |
+
def load_states(save_file_name):
    """Load writer/setting states from <save_file_name>.json.

    Returns:
        (writer, setting) tuple for the two gr.State outputs.

    Raises:
        gr.Error: when the state file does not exist.
    """
    import json
    json_file_name = save_file_name + '.json'
    try:
        with open(json_file_name, 'r', encoding='utf-8') as f:
            states = json.load(f)
        gr.Info(f"状态文件已加载:{json_file_name}")
        # Touch render_time so the settings panel is forced to re-render.
        states['setting']['render_time'] = time.time()
        # To guarantee the settings UI re-renders: selecting a model does not
        # assign setting_state, so the settings UI must hold the very same
        # object as setting_state.
        return states['writer'], states['setting']
    except FileNotFoundError:
        raise gr.Error(f"未找到保存的状态文件:{json_file_name}")
|
| 76 |
+
|
| 77 |
+
idea_textbox = gr.Textbox(placeholder='用一段话描述你要写的小说,或者从下方示例中选择一个创意...', lines=2, scale=1, label=None, show_label=False, container=False, max_length=1000)
|
| 78 |
+
|
| 79 |
+
gr.Examples(
|
| 80 |
+
label='示例',
|
| 81 |
+
examples=examples,
|
| 82 |
+
inputs=[idea_textbox],
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
with gr.Row():
|
| 86 |
+
outline_btn = gr.Button("创作大纲", scale=1, min_width=1, interactive = True, variant='primary')
|
| 87 |
+
chapters_btn = gr.Button("创作剧情", scale=1, min_width=1, interactive = False, variant='secondary')
|
| 88 |
+
draft_btn = gr.Button("创作正文", scale=1, min_width=1, interactive = False, variant='secondary')
|
| 89 |
+
auto_checkbox = gr.Checkbox(label='一键生成', scale=1, value=False, visible=False) # TODO: V1.10版本 “自动”尚不完善,暂不显示
|
| 90 |
+
|
| 91 |
+
progress_md = create_progress_md(writer_state.value)
|
| 92 |
+
text_md = create_text_md(writer_state.value)
|
| 93 |
+
|
| 94 |
+
@gr.render(inputs=writer_state)
def create_prompt_preview(writer):
    """Re-rendered accordion that previews captured LLM prompts while paused."""
    prompt_outputs = writer['prompt_outputs'] if 'prompt_outputs' in writer else []
    with gr.Accordion("Prompt预览", open=bool(prompt_outputs)):
        pause_on_prompt_finished_checkbox = gr.Checkbox(label='允许在LLM响应完成后,预览Prompt', scale=1, value=writer['pause_on_prompt_finished_flag'])

        # One tab per captured prompt exchange.
        for i, prompt_output in enumerate(prompt_outputs, 1):
            with gr.Tab(f"Prompt {i}"):
                gr.Chatbot(messages2chatbot(prompt_output['response_msgs']), type='messages')
        if not prompt_outputs:
            gr.Markdown('当前没有可预览的Prompt。')

        continue_btn = gr.Button('继续', visible=bool(prompt_outputs), variant='primary')

        def on_pause_on_prompt_finished(value):
            # Toggle whether generation pauses after each finished LLM
            # response; mutates the writer dict captured by this render.
            if value:
                gr.Info("在LLM响应完成后,将可以预览Prompt")
            writer['pause_on_prompt_finished_flag'] = value

        pause_on_prompt_finished_checkbox.change(on_pause_on_prompt_finished, [pause_on_prompt_finished_checkbox])

        def on_continue(writer):
            # Clear the pause and drop previewed prompts so generation resumes.
            writer['pause_flag'] = False
            writer['prompt_outputs'] = []
            return writer

        continue_btn.click(on_continue, writer_state, writer_state)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
with gr.Row():
|
| 124 |
+
rewrite_all_button = gr.Button("开始创作", min_width=100, scale=2, variant='secondary', interactive=False)
|
| 125 |
+
suggestion_dropdown = gr.Dropdown(choices=[], min_width=100, scale=2, label=None, show_label=False, container=False, allow_custom_value=False)
|
| 126 |
+
quote_checkbox = gr.Checkbox(label='允许引用', min_width=100, scale=2, value=False)
|
| 127 |
+
gr.Textbox('窗口大小:', container=False, text_align='right', scale=1, min_width=100)
|
| 128 |
+
chunk_length_dropdown = gr.Dropdown(choices=[], min_width=80, scale=1, label=None, show_label=False, container=False, allow_custom_value=False)
|
| 129 |
+
|
| 130 |
+
quote_md = gr.Markdown(visible=False)
|
| 131 |
+
|
| 132 |
+
def on_quote_checkbox_change(writer, value):
    """Handle toggling of the 'allow quote' checkbox.

    Outline mode does not support quoting; otherwise any toggle clears the
    currently stored quote and hides the quote preview markdown.
    """
    if writer['current_w'] == 'outline_w':
        gr.Info("大纲创作不支持引用\n考虑在剧情和正文创作中使用吧~")
        return gr.update(value=False, visible=False)

    if value:
        gr.Info("允许引用(右键或Ctrl+C复制你想引用的文本)")
    # Reset any previously captured quote regardless of toggle direction.
    # NOTE(review): indentation reconstructed from a mangled rendering —
    # confirm these resets are meant to run on both enable and disable.
    writer['quote_span'] = None
    writer['quoted_text'] = ''
    return gr.update(value=None, visible=False)
|
| 142 |
+
|
| 143 |
+
quote_checkbox.change(on_quote_checkbox_change, [writer_state, quote_checkbox], [quote_md])
|
| 144 |
+
|
| 145 |
+
def on_chunk_length_change(writer, value):
    """Persist the chosen y-window size on the active sub-writer and echo it back."""
    writer[writer['current_w']]['y_chunk_length'] = value
    return gr.update(value=value)
|
| 149 |
+
|
| 150 |
+
chunk_length_dropdown.change(on_chunk_length_change, [writer_state, chunk_length_dropdown], [chunk_length_dropdown])
|
| 151 |
+
|
| 152 |
+
def on_copy_handle(text, writer, setting, quote_checkbox):
    """Handle a copy event from the text view: try to register it as a quote.

    Returns a gr.update for the quote preview markdown (visible with the
    formatted quote on success, hidden otherwise).
    """
    # gr.Info(f"Copy: {text}")
    text = text.strip()

    # Pending edits must be resolved before a new quote can be taken.
    if has_accept(writer):
        gr.Info('考虑先接受或拒绝修改哦~')
        return gr.update(visible=False)

    if len(text) < 10:
        gr.Info('选中的文本太短,无法引用')
        return gr.update(visible=False)

    if quote_checkbox:
        quote_span, quoted_text = match_quote_text(writer, setting, text)
        if quote_span:
            writer['quote_span'] = quote_span
            writer['quoted_text'] = quoted_text
            # Abbreviate long quotes and render them as a markdown blockquote
            # wrapped in a code fence.
            lines = quoted_text.split('\n')
            if len(lines) > 10:
                lines[5:-5] = ['......']
            lines = ['```', ] + lines + ['```', ]
            quoted_text = '\n'.join(["> " + e for e in lines])
            return gr.update(value=quoted_text, visible=True)
        else:
            gr.Info('未找到匹配的引用文本')

    # Fall-through: quoting disabled or no match — clear any stored quote.
    # NOTE(review): indentation reconstructed from a mangled rendering —
    # confirm these resets sit at function level, not inside the if-branch.
    writer['quote_span'] = None
    writer['quoted_text'] = ''
    return gr.update(visible=False)
|
| 181 |
+
|
| 182 |
+
on_copy(on_copy_handle, [writer_state, setting_state, quote_checkbox], [quote_md])
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
suggestion_textbox = gr.Textbox(max_length=1000, placeholder='在这里输入你的意见,或者从右上单选框选择', lines=2, scale=1, label=None, show_label=False, container=False)
|
| 186 |
+
|
| 187 |
+
with gr.Row():
|
| 188 |
+
accept_button = gr.Button("接受", scale=1, min_width=1, variant='secondary', interactive=False)
|
| 189 |
+
pause_button = gr.Button("暂停", scale=1, min_width=1, variant='secondary', visible=RENDER_STOP_BTN)
|
| 190 |
+
stop_button = gr.Button("取消", scale=1, min_width=1, variant='secondary')
|
| 191 |
+
flash_button = gr.Button("刷新", scale=1, min_width=1, variant='secondary')
|
| 192 |
+
|
| 193 |
+
def flash_interface(writer):
    """Recompute the enabled/variant state of every control from the writer state.

    Returns a tuple of gr.update objects (plus two rebuilt markdown values) in
    the exact order expected by flash_event's outputs list. Note: the local
    names below deliberately shadow the outer widget variables — each is
    rebound to a gr.update payload, not to the component itself.
    """
    current_w_name = writer['current_w']

    # Accepting is possible once there are pending chunks and nothing running;
    # writing is possible when there is input text and no pending accept.
    can_accept_flag = has_accept(writer) and not is_running(writer)
    can_write_flag = not writer_x_is_empty(writer, current_w_name) and not can_accept_flag

    # NOTE(review): all three cases currently produce an identical update, and
    # there is no `case _` — an unknown writer name would leave
    # rewrite_all_button unbound (NameError). Left as-is; confirm before
    # collapsing the branches.
    match current_w_name:
        case 'outline_w':
            rewrite_all_button = gr.update(value='开始创作', variant='primary' if can_write_flag else 'secondary', interactive=can_write_flag)
        case 'chapters_w':
            rewrite_all_button = gr.update(value='开始创作', variant='primary' if can_write_flag else 'secondary', interactive=can_write_flag)
        case 'draft_w':
            rewrite_all_button = gr.update(value='开始创作', variant='primary' if can_write_flag else 'secondary', interactive=can_write_flag)

    accept_button = gr.update(variant='primary' if can_accept_flag else 'secondary', interactive=can_accept_flag)

    # Update the interactive state of chapters_btn and draft_btn: each stage
    # unlocks only after the previous stage has produced output.
    outline_btn = gr.update(
        variant='primary' if current_w_name == 'outline_w' else 'secondary'
    )
    chapters_btn = gr.update(
        interactive=not writer_y_is_empty(writer, 'outline_w'),
        variant='primary' if current_w_name == 'chapters_w' else 'secondary'
    )
    draft_btn = gr.update(
        interactive=not writer_y_is_empty(writer, 'chapters_w'),
        variant='primary' if current_w_name == 'draft_w' else 'secondary'
    )

    pause_button = gr.update(
        value="继续" if writer['pause_flag'] else "暂停",
        variant='secondary',
    )

    suggestion_choices = writer['suggestions'][current_w_name]
    # suggestion_choices = ['自动', ] + writer['suggestions'][current_w_name] # TODO: V1.10 — "auto" is not ready yet, so it is hidden for now
    # Pre-select the first suggestion only when there is no output yet.
    if writer_y_is_empty(writer, current_w_name):
        suggestion_dropdown = gr.update(choices=suggestion_choices, value=suggestion_choices[0])
    else:
        suggestion_dropdown = gr.update(choices=suggestion_choices,)

    chunk_length_choices = writer['chunk_length'][current_w_name]
    # Keep the stored window size if one exists, else default to the first choice.
    if cur_chunk_length := writer[current_w_name].get('y_chunk_length', None):
        chunk_length_dropdown = gr.update(choices=chunk_length_choices, value=cur_chunk_length)
    else:
        chunk_length_dropdown = gr.update(choices=chunk_length_choices, value=chunk_length_choices[0])

    # Order must match flash_event['outputs'].
    return (
        create_text_md(writer),
        create_progress_md(writer),
        rewrite_all_button,
        accept_button,
        outline_btn,
        chapters_btn,
        draft_btn,
        pause_button,
        suggestion_dropdown,
        chunk_length_dropdown
    )
|
| 252 |
+
|
| 253 |
+
# 更新 flash_event 字典以包含新的输出
|
| 254 |
+
flash_event = dict(
|
| 255 |
+
fn=flash_interface,
|
| 256 |
+
inputs=[writer_state],
|
| 257 |
+
outputs=[
|
| 258 |
+
text_md,
|
| 259 |
+
progress_md,
|
| 260 |
+
rewrite_all_button,
|
| 261 |
+
accept_button,
|
| 262 |
+
outline_btn,
|
| 263 |
+
chapters_btn,
|
| 264 |
+
draft_btn,
|
| 265 |
+
pause_button,
|
| 266 |
+
suggestion_dropdown,
|
| 267 |
+
chunk_length_dropdown
|
| 268 |
+
]
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
flash_button.click(**flash_event)
|
| 272 |
+
if RENDER_SAVE_LOAD_BTN:
|
| 273 |
+
save_button.click(save_states, inputs=[save_file_name, writer_state, setting_state], outputs=[])
|
| 274 |
+
load_button.click(load_states, inputs=[save_file_name], outputs=[writer_state, setting_state]).success(**flash_event)
|
| 275 |
+
# stop_write_long_novel_button.click(on_cancel, inputs=[writer_state])
|
| 276 |
+
stop_button.click(try_cancel, inputs=[writer_state]).success(**flash_event).success(lambda :gr.update(), None, writer_state)
|
| 277 |
+
# TODO: stop_btn对writer_state的更新没有起效
|
| 278 |
+
|
| 279 |
+
@cancellable
def _on_write_all(writer, setting, auto_write=False, suggestion=None):
    """Drive one full write pass for the current stage via the backend.

    Generator handler (wrapped by @cancellable): yields (markdown, writer)
    tuples to stream progress into the UI, honours writer['pause_flag'] /
    writer['cancel_flag'], and on completion merges the backend's final
    writer back into the shared *writer* dict.
    """
    current_w_name = writer['current_w']

    if writer_x_is_empty(writer, current_w_name):
        gr.Info('请先输入需要创作的内容!')
        return

    writer['prompt_outputs'].clear()

    # If a quote selection is active, re-match it against the current text;
    # a mismatch means the quoted text no longer exists.
    if writer['quote_span']:
        quote_span, quoted_text = match_quote_text(writer, setting, writer['quoted_text'])
        if quote_span != writer['quote_span'] or quoted_text != writer['quoted_text']:
            raise gr.Error('引用文本不存在!')

    generator = call_write(process_writer_to_backend(writer), setting, auto_write, suggestion)

    new_writer = None
    while True:
        try:
            kp_msg = next(generator)
            if isinstance(kp_msg, KeyPointMsg):
                # TODO: the KeyPointMsg design makes this branching awkward; revisit later.
                if kp_msg.is_prompt() and kp_msg.is_finished() and writer['pause_on_prompt_finished_flag']:
                    # Auto-pause once the LLM response finishes so the user can inspect the prompt.
                    gr.Info('LLM响应完成,可以预览Prompt')
                    writer['pause_flag'] = True
                    if new_writer is None: continue
                elif kp_msg.is_title(): # TODO: title nodes have no finish logic yet
                    # Checkpoint-saving at key nodes was removed here (buggy); may return later.
                    continue
                else:
                    continue
            else:
                # Any non-KeyPointMsg item is an updated backend writer snapshot.
                new_writer = kp_msg

            if writer['pause_flag']:
                # Copy prompt_outputs into writer_state so the prompt is visible while
                # paused; deep copy prevents continuous re-rendering as new_writer mutates.
                writer['prompt_outputs'] = copy.deepcopy(new_writer['prompt_outputs'])
                yield create_text_md(new_writer), writer

                # Busy-wait until resumed or cancelled.
                while writer['pause_flag'] and not writer['cancel_flag']:
                    time.sleep(0.1)
            else:
                yield create_text_md(new_writer), gr.update()
        except StopIteration as e:
            # Final state: merge the backend result into the shared writer dict.
            process_writer_from_backend(writer, e.value)
            yield create_text_md(writer), writer
            if has_accept(writer):
                gr.Info('创作完成!点击接受按钮接受修改。')
            else:
                gr.Info('本次创作没有任何更改。') # typically the review judged no change was needed
            return
|
| 337 |
+
|
| 338 |
+
def on_auto_write_all(writer, setting, auto_write):
    """Run the full auto-write pipeline, but only when auto mode is enabled.

    With auto_write False this generator simply yields nothing; the manual
    path goes through on_write_all with an explicit suggestion instead.
    """
    if not auto_write:
        return
    yield from _on_write_all(writer, setting, True)
|
| 345 |
+
|
| 346 |
+
# Event spec shared by the three stage buttons: stream (text_md, writer)
# updates from on_auto_write_all; queued with a per-event concurrency cap.
writer_all_events = dict(
    fn=on_auto_write_all,
    queue=True,
    inputs=[writer_state, setting_state, auto_checkbox],
    outputs=[text_md, writer_state],
    concurrency_limit=10
)
|
| 353 |
+
|
| 354 |
+
def on_init_outline(idea, writer):
    """Seed the outline stage of *writer* from the idea textbox.

    Leaves writer untouched (returning gr.update()) when the idea is blank;
    otherwise copies the freshly initialised outline-related keys in place.
    """
    if not idea.strip():
        gr.Info("先输入小说简介或从示例中选择一个")
        return gr.update()
    fresh = init_writer(idea)
    for key in ('current_w', 'outline_w', 'prompt_outputs'):
        writer[key] = fresh[key]
    return writer
|
| 363 |
+
|
| 364 |
+
# Stage buttons: seed the stage, run the writer pipeline, then refresh the UI.
outline_btn.click(on_init_outline, inputs=[idea_textbox, writer_state], outputs=[writer_state]).success(**writer_all_events).then(**flash_event)
chapters_btn.click(lambda writer: init_chapters_w(writer), inputs=[writer_state], outputs=[writer_state]).success(**writer_all_events).then(**flash_event)
draft_btn.click(lambda writer: init_draft_w(writer), inputs=[writer_state], outputs=[writer_state]).success(**writer_all_events).then(**flash_event)
|
| 367 |
+
|
| 368 |
+
def on_select_suggestion(writer, setting, choice):
    """Load the prompt file behind a suggestion choice into the textbox.

    The special choice '自动' hides the textbox instead of loading a file.
    A leading "user:\\n" marker in the prompt file is stripped.
    """
    if choice == '自动':
        return gr.update(value=choice, visible=False)

    dirname = writer['suggestions_dirname'][writer['current_w']]
    suggestion = clean_txt_content(load_prompt(dirname, choice))
    suggestion = suggestion.removeprefix("user:\n")
    return gr.update(value=suggestion, visible=True)
|
| 379 |
+
|
| 380 |
+
# Selecting a suggestion loads its prompt text into the suggestion textbox.
suggestion_dropdown.change(on_select_suggestion, inputs=[writer_state, setting_state, suggestion_dropdown], outputs=[suggestion_textbox])
|
| 381 |
+
|
| 382 |
+
def on_write_all(writer, setting, suggestion):
    """Manual write pass: requires a non-blank suggestion from the user."""
    if suggestion.strip():
        yield from _on_write_all(writer, setting, False, suggestion)
    else:
        gr.Info('需要输入创作意见!')
|
| 387 |
+
|
| 388 |
+
# Manual rewrite: stream progress into text_md, then refresh the whole header.
rewrite_all_button.click(
    on_write_all,
    queue=True,
    inputs=[writer_state, setting_state, suggestion_textbox],
    outputs=[text_md, writer_state],
    concurrency_limit=10
).then(**flash_event)
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
@cancellable
def on_accept_write(writer, setting):
    """Apply the pending apply_chunks of the current stage via the backend.

    Generator handler (wrapped by @cancellable): raises when there is nothing
    to accept, otherwise merges the backend result into *writer* and yields
    one final (markdown, writer) update.
    """
    current_w_name = writer['current_w']
    current_w = writer[current_w_name]

    if not current_w['apply_chunks']:
        raise gr.Error('请先进行创作!')

    new_writer = call_accept(process_writer_to_backend(writer), setting)
    process_writer_from_backend(writer, new_writer)
    yield create_text_md(writer), writer
|
| 408 |
+
|
| 409 |
+
# Accept pending edits, then refresh the header widgets.
accept_button.click(fn=on_accept_write, inputs=[writer_state, setting_state], outputs=[text_md, writer_state]).then(**flash_event)
|
| 410 |
+
|
| 411 |
+
def toggle_pause(writer):
    """Flip the pause flag of the running operation and relabel the button."""
    if not is_running(writer):
        gr.Info('当前没有正在进行的操作')
        return gr.update()

    paused = not writer['pause_flag']
    writer['pause_flag'] = paused
    # Button shows the action it will perform next: resume while paused.
    return gr.update(value="继续" if paused else "暂停")
|
| 419 |
+
|
| 420 |
+
# Pause/resume the current operation; the handler relabels the button itself.
pause_button.click(
    toggle_pause,
    inputs=[writer_state],
    outputs=[pause_button]
)
|
| 425 |
+
|
| 426 |
+
# Re-render the settings accordion whenever setting_state changes.
@gr.render(inputs=setting_state)
def _render_setting(setting):
    return render_setting(setting, setting_state)
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
# Enable the request queue (required by the streaming/queued handlers above)
# and serve on all interfaces for container deployment (see Dockerfile).
demo.queue()
demo.launch(server_name="0.0.0.0", server_port=7860)
|
| 434 |
+
|
| 435 |
+
|
core/frontend_copy.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
|
| 3 |
+
enable_copy_js = """
|
| 4 |
+
<script>
|
| 5 |
+
document.addEventListener('copy', function(e) {
|
| 6 |
+
// 获取选中的文本
|
| 7 |
+
var selectedText = window.getSelection().toString();
|
| 8 |
+
if(selectedText) {
|
| 9 |
+
// 直接触发 gradio 组件的更新
|
| 10 |
+
const textbox = document.getElementById('copy_textbox');
|
| 11 |
+
if(textbox) {
|
| 12 |
+
textbox.querySelector('textarea').value = selectedText;
|
| 13 |
+
// 触发 change 事件以更新 Gradio 状态
|
| 14 |
+
textbox.querySelector('textarea').dispatchEvent(new Event('input', { bubbles: true }));
|
| 15 |
+
}
|
| 16 |
+
}
|
| 17 |
+
});
|
| 18 |
+
</script>
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def on_copy(fn, inputs, outputs):
    """Register *fn* as a handler for browser copy events.

    Works together with enable_copy_js: the script writes the copied text into
    the hidden #copy_textbox, whose change event then calls *fn* with the
    copied text prepended to *inputs*.  Returns the event dependency.
    """
    copy_textbox = gr.Textbox(elem_id="copy_textbox", visible=False)
    return copy_textbox.change(fn, [copy_textbox] + inputs, outputs)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# with gr.Blocks(head=enable_copy_js) as demo:
|
| 27 |
+
# gr.Markdown("Hello\nTest Copy")
|
| 28 |
+
# copy_textbox = gr.Textbox(elem_id="copy_textbox", visible=False)
|
| 29 |
+
|
| 30 |
+
# def copy_handle(text):
|
| 31 |
+
# gr.Info(text)
|
| 32 |
+
|
| 33 |
+
# copy_textbox.change(copy_handle, copy_textbox)
|
| 34 |
+
|
| 35 |
+
# demo.launch()
|
core/frontend_setting.py
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from enum import Enum, auto
|
| 3 |
+
|
| 4 |
+
from llm_api import ModelConfig, wenxin_model_config, doubao_model_config, gpt_model_config, zhipuai_model_config, test_stream_chat
|
| 5 |
+
from config import API_SETTINGS, RENDER_SETTING_API_TEST_BTN, ENABLE_SETTING_SELECT_SUB_MODEL
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Provider:
    """Display names of the supported LLM providers.

    These strings double as the dropdown choice values and as the keys used
    to branch in render_setting, so they must stay byte-stable.
    """
    GPT = "GPT(OpenAI)"
    WENXIN = "文心(百度)"
    DOUBAO = "豆包(字节跳动)"
    ZHIPUAI = "GLM(智谱)"
    OTHERS = '其他'
|
| 14 |
+
|
| 15 |
+
def deep_update(d, u):
    """Recursively merge dictionary *u* into *d* in place.

    Nested dicts present in both are merged key-by-key; any other value in
    *u* overwrites (or creates) the corresponding entry in *d*.
    """
    for key, value in u.items():
        target = d.get(key)
        if isinstance(value, dict) and isinstance(target, dict):
            deep_update(target, value)
        else:
            d[key] = value
|
| 22 |
+
|
| 23 |
+
def new_setting():
    """Build the default settings dict, overlaid with values from API_SETTINGS.

    The 'model' / 'sub_model' sections of API_SETTINGS become ModelConfig
    instances; every other section is deep-merged over the provider defaults
    below.

    Fix: the previous implementation called ``API_SETTINGS.pop('model')`` /
    ``pop('sub_model')``, mutating the module-level config so that a second
    call raised KeyError. We now read without mutating and simply exclude
    those two keys from the deep merge (so the ModelConfig instances are not
    clobbered by raw dicts).
    """
    model_config = API_SETTINGS['model']
    sub_model_config = API_SETTINGS['sub_model']
    # Everything except the two ModelConfig sections is deep-merged below.
    overrides = {k: v for k, v in API_SETTINGS.items() if k not in ('model', 'sub_model')}

    setting = dict(
        model=ModelConfig(**model_config),
        sub_model=ModelConfig(**sub_model_config),
        render_count=0,
        provider_name=Provider.GPT,
        wenxin={
            'ak': '',
            'sk': '',
            'default_model': 'ERNIE-Novel-8K',
            'default_sub_model': 'ERNIE-3.5-8K',
            'available_models': list(wenxin_model_config.keys())
        },
        doubao={
            'api_key': '',
            'main_endpoint_id': '',
            'sub_endpoint_id': '',
            'default_model': 'doubao-pro-32k',
            'default_sub_model': 'doubao-lite-32k',
            'available_models': list(doubao_model_config.keys())
        },
        gpt={
            'api_key': '',
            'base_url': '',
            'proxies': '',
            'default_model': 'gpt-4o',
            'default_sub_model': 'gpt-4o-mini',
            'available_models': list(gpt_model_config.keys())
        },
        zhipuai={
            'api_key': '',
            'default_model': 'glm-4-plus',
            'default_sub_model': 'glm-4-flashx',
            'available_models': list(zhipuai_model_config.keys())
        },
        others={
            'api_key': '',
            'base_url': '',
            'default_model': '',
            'default_sub_model': '',
            'available_models': []
        }
    )

    deep_update(setting, overrides)

    return setting
|
| 73 |
+
|
| 74 |
+
# @gr.render(inputs=setting_state)
|
| 75 |
+
def render_setting(setting, setting_state):
    """Render the API-settings accordion for the currently selected provider.

    Invoked per render pass (via @gr.render on setting_state), so every
    widget and handler below is rebuilt each time.  All handlers close over
    *setting* / *provider_config* and mutate them in place; returning
    *setting* from on_select_provider re-triggers this render.
    """
    with gr.Accordion("API 设置"):
        with gr.Row():
            provider_name = gr.Dropdown(
                choices=[Provider.GPT, Provider.WENXIN, Provider.DOUBAO, Provider.ZHIPUAI, Provider.OTHERS],
                value=setting['provider_name'],
                label="模型提供商",
                scale=1
            )

            def on_select_provider(provider_name):
                # Persist the choice; returning the dict re-renders this panel.
                setting['provider_name'] = provider_name
                return setting

            provider_name.select(fn=on_select_provider, inputs=provider_name, outputs=[setting_state])

            # Pick the per-provider config section.
            # NOTE(review): no default case — an unknown provider_name would
            # leave provider_config unbound (NameError below); presumably the
            # dropdown is the only writer of this field — confirm.
            match setting['provider_name']:
                case Provider.WENXIN:
                    provider_config = setting['wenxin']
                case Provider.DOUBAO:
                    provider_config = setting['doubao']
                case Provider.GPT:
                    provider_config = setting['gpt']
                case Provider.ZHIPUAI:
                    provider_config = setting['zhipuai']
                case Provider.OTHERS:
                    provider_config = setting['others']

            # Free-form model names are only allowed for the "others" provider.
            main_model = gr.Dropdown(
                choices=provider_config['available_models'],
                value=provider_config['default_model'],
                label="主模型",
                scale=1,
                allow_custom_value=setting['provider_name'] == Provider.OTHERS
            )

            sub_model = gr.Dropdown(
                choices=provider_config['available_models'],
                value=provider_config['default_sub_model'],
                label="辅助模型",
                scale=1,
                allow_custom_value=setting['provider_name'] == Provider.OTHERS,
                interactive=ENABLE_SETTING_SELECT_SUB_MODEL
            )

        # Credential textboxes — which ones exist depends on the provider.
        with gr.Row():
            if setting['provider_name'] == Provider.WENXIN:
                baidu_access_key = gr.Textbox(
                    value=provider_config['ak'],
                    label='Baidu Access Key',
                    lines=1,
                    placeholder='Enter your Baidu access key here',
                    interactive=True,
                    scale=10,
                    type='password'
                )
                baidu_secret_key = gr.Textbox(
                    value=provider_config['sk'],
                    label='Baidu Secret Key',
                    lines=1,
                    placeholder='Enter your Baidu secret key here',
                    interactive=True,
                    scale=10,
                    type='password'
                )

            elif setting['provider_name'] == Provider.DOUBAO:
                doubao_api_key = gr.Textbox(
                    value=provider_config['api_key'],
                    label='Doubao API Key',
                    lines=1,
                    placeholder='Enter your Doubao API key here',
                    interactive=True,
                    scale=10,
                    type='password'
                )
                main_endpoint_id = gr.Textbox(
                    value=provider_config['main_endpoint_id'],
                    label='Main Endpoint ID',
                    lines=1,
                    placeholder='Enter your main endpoint ID here',
                    interactive=True,
                    scale=10,
                    type='password'
                )
                sub_endpoint_id = gr.Textbox(
                    value=provider_config['sub_endpoint_id'],
                    label='Sub Endpoint ID',
                    lines=1,
                    placeholder='Enter your sub endpoint ID here',
                    interactive=True,
                    scale=10,
                    type='password'
                )

            elif setting['provider_name'] in [Provider.GPT, Provider.OTHERS]:
                gpt_api_key = gr.Textbox(
                    value=provider_config['api_key'],
                    label='OpenAI API Key',
                    lines=1,
                    placeholder='Enter your OpenAI API key here',
                    interactive=True,
                    scale=10,
                    type='password'
                )
                base_url = gr.Textbox(
                    value=provider_config['base_url'],
                    label='API Base URL',
                    lines=1,
                    placeholder='Enter API base URL here',
                    interactive=True,
                    scale=10,
                    type='password'
                )

            elif setting['provider_name'] == Provider.ZHIPUAI:
                zhipuai_api_key = gr.Textbox(
                    value=provider_config['api_key'],
                    label='ZhipuAI API Key',
                    lines=1,
                    placeholder='Enter your ZhipuAI API key here',
                    interactive=True,
                    scale=10,
                    type='password'
                )

        # Wiring: each provider branch defines an on_submit that rebuilds both
        # ModelConfig objects from current widget values, calls it once eagerly
        # to sync state, then re-runs it on every relevant .change event.
        with gr.Row():
            if setting['provider_name'] == Provider.WENXIN:
                def on_submit(main_model, sub_model, baidu_access_key, baidu_secret_key):
                    provider_config['ak'] = baidu_access_key
                    provider_config['sk'] = baidu_secret_key

                    setting['model'] = ModelConfig(
                        model=main_model,
                        ak=baidu_access_key,
                        sk=baidu_secret_key,
                        max_tokens=4096
                    )
                    setting['sub_model'] = ModelConfig(
                        model=sub_model,
                        ak=baidu_access_key,
                        sk=baidu_secret_key,
                        max_tokens=4096
                    )

                submit_event = dict(
                    fn=on_submit,
                    inputs=[main_model, sub_model, baidu_access_key, baidu_secret_key],
                )

                on_submit(main_model.value, sub_model.value, baidu_access_key.value, baidu_secret_key.value)

                main_model.change(**submit_event)
                sub_model.change(**submit_event)
                baidu_access_key.change(**submit_event)
                baidu_secret_key.change(**submit_event)

            elif setting['provider_name'] == Provider.DOUBAO:
                def on_submit(main_model, sub_model, doubao_api_key, main_endpoint_id, sub_endpoint_id):
                    provider_config['api_key'] = doubao_api_key
                    provider_config['main_endpoint_id'] = main_endpoint_id
                    provider_config['sub_endpoint_id'] = sub_endpoint_id

                    setting['model'] = ModelConfig(
                        model=main_model,
                        api_key=doubao_api_key,
                        endpoint_id=main_endpoint_id,
                        max_tokens=4096
                    )
                    setting['sub_model'] = ModelConfig(
                        model=sub_model,
                        api_key=doubao_api_key,
                        endpoint_id=sub_endpoint_id,
                        max_tokens=4096
                    )

                submit_event = dict(
                    fn=on_submit,
                    inputs=[main_model, sub_model, doubao_api_key, main_endpoint_id, sub_endpoint_id],
                )

                on_submit(main_model.value, sub_model.value, doubao_api_key.value, main_endpoint_id.value, sub_endpoint_id.value)

                main_model.change(**submit_event)
                sub_model.change(**submit_event)
                doubao_api_key.change(**submit_event)
                main_endpoint_id.change(**submit_event)
                sub_endpoint_id.change(**submit_event)

            elif setting['provider_name'] in [Provider.GPT, Provider.OTHERS]:
                def on_submit(main_model, sub_model, gpt_api_key, base_url):
                    provider_config['api_key'] = gpt_api_key
                    provider_config['base_url'] = base_url.strip()

                    setting['model'] = ModelConfig(
                        model=main_model,
                        api_key=provider_config['api_key'],
                        base_url=provider_config['base_url'],
                        max_tokens=4096,
                        proxies=provider_config.get('proxies', None),
                    )
                    setting['sub_model'] = ModelConfig(
                        model=sub_model,
                        api_key=provider_config['api_key'],
                        base_url=provider_config['base_url'],
                        max_tokens=4096,
                        proxies=provider_config.get('proxies', None),
                    )

                submit_event = dict(
                    fn=on_submit,
                    inputs=[main_model, sub_model, gpt_api_key, base_url],
                )

                on_submit(main_model.value, sub_model.value, gpt_api_key.value, base_url.value)

                main_model.change(**submit_event)
                sub_model.change(**submit_event)
                gpt_api_key.change(**submit_event)
                base_url.change(**submit_event)

            elif setting['provider_name'] == Provider.ZHIPUAI:
                def on_submit(main_model, sub_model, zhipuai_api_key):
                    provider_config['api_key'] = zhipuai_api_key

                    setting['model'] = ModelConfig(
                        model=main_model,
                        api_key=zhipuai_api_key,
                        max_tokens=4096
                    )
                    setting['sub_model'] = ModelConfig(
                        model=sub_model,
                        api_key=zhipuai_api_key,
                        max_tokens=4096
                    )

                submit_event = dict(
                    fn=on_submit,
                    inputs=[main_model, sub_model, zhipuai_api_key],
                )

                on_submit(main_model.value, sub_model.value, zhipuai_api_key.value)

                main_model.change(**submit_event)
                sub_model.change(**submit_event)
                zhipuai_api_key.change(**submit_event)

            if RENDER_SETTING_API_TEST_BTN:
                test_btn = gr.Button("测试")
                test_report = gr.Textbox(show_label=False, container=False, value='', interactive=False, scale=10)

                def on_test_llm_api():
                    # Generator handler: streams the API test result into test_report.
                    if not setting['model']['model'].strip():
                        return gr.Info('主模型名不能为空')

                    if not setting['sub_model']['model'].strip():
                        return gr.Info('辅助模型名不能为空')

                    try:
                        response1 = yield from test_stream_chat(setting['model'])
                        response2 = yield from test_stream_chat(setting['sub_model'])
                        report_text = f"User:1+1=?\n主模型 :{response1.response}({response1.cost_info})\n辅助模型:{response2.response}({response2.cost_info})\n测试通过!"
                        yield report_text
                    except Exception as e:
                        yield f"测试失败:{str(e)}"

            if RENDER_SETTING_API_TEST_BTN:
                test_btn.click(
                    on_test_llm_api,
                    outputs=[test_report]
                )
|
core/frontend_utils.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import functools
|
| 3 |
+
import pickle
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
import gradio as gr
|
| 7 |
+
|
| 8 |
+
from core.writer import Chunk
|
| 9 |
+
|
| 10 |
+
# Static HTML banner rendered at the top of the Gradio app.
title = """
<div style="text-align: center; padding: 10px 20px;">
<h1 style="margin: 0 0 5px 0;">🖋️ Long-Novel-GPT 1.10</h1>
<p style="margin: 0;"><em>AI一键生成长篇小说</em></p>
</div>
"""

# Usage notes shown to the user (runtime string — kept verbatim).
info = \
"""1. 当前Demo支持GPT、Claude、文心、豆包、GLM等模型,并且已经配置了API-Key,默认模型为GPT4o,最大线程数为5。
2. 可以选中**示例**中的任意一个创意,然后点击**创作大纲**来初始化大纲。
3. 初始化后,点击**开始创作**按钮,可以不断创作大纲,直到满意为止。
4. 创建完大纲后,点击**创作剧情**按钮,之后重复以上流程。
5. 选中**一键生成**后,再次点击左侧按钮可以一键生成。
6. 如果遇到任何无法解决的问题,请点击**刷新**按钮。
7. 如果问题还是无法解决,请刷新浏览器页面,这会导致丢失所有数据,请手动备份重要文本。
"""
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def init_writer(idea, check_empty=True):
    """Build a fresh writer-state dict seeded with *idea* as the outline input.

    The writer dict holds one sub-dict per stage (outline_w / chapters_w /
    draft_w), each with xy_pairs of (source text, generated text) plus cost
    counters and pending apply_chunks, alongside UI control flags and the
    per-stage prompt/suggestion configuration.  Raises when *idea* is empty
    and check_empty is True.
    """
    outline_w = dict(
        current_cost=0,
        total_cost=0,
        currency_symbol='¥',
        xy_pairs=[(idea, '')],  # x = the idea text, y = generated outline (empty for now)
        apply_chunks={},
    )
    chapters_w = dict(
        current_cost=0,
        total_cost=0,
        currency_symbol='¥',
        xy_pairs=[('', '')],
        apply_chunks={},
    )
    draft_w = dict(
        current_cost=0,
        total_cost=0,
        currency_symbol='¥',
        xy_pairs=[('', '')],
        apply_chunks={},
    )
    # Suggestion dropdown choices for each stage.
    suggestions = dict(
        outline_w = ['新建大纲', '扩写大纲', '润色大纲'],
        chapters_w = ['新建剧情', '扩写剧情', '润色剧情'],
        draft_w = ['新建正文', '扩写正文', '润色正文'],
    )

    # Prompt directories backing the suggestion choices above.
    suggestions_dirname = dict(
        outline_w = 'prompts/创作大纲',
        chapters_w = 'prompts/创作剧情',
        draft_w = 'prompts/创作正文',
    )

    # Selectable chunk sizes per stage; index 0 is the default choice.
    chunk_length = dict(
        outline_w = [4_000, ],
        chapters_w = [500, 200, 1000, 2000],
        draft_w = [1000, 500, 2000, 3000],
    )

    writer = dict(
        current_w='outline_w',
        outline_w=outline_w,
        chapters_w=chapters_w,
        draft_w=draft_w,
        running_flag=False,
        cancel_flag=False,  # set True to cancel an in-flight operation
        pause_flag=False,  # set True to pause an in-flight operation
        progress={},
        prompt_outputs=[],  # when populated, prompt outputs are shown in the Gradio UI
        suggestions=suggestions,
        suggestions_dirname=suggestions_dirname,
        pause_on_prompt_finished_flag = False,
        quote_span = None,
        chunk_length = chunk_length,
    )

    current_w_name = writer['current_w']
    if check_empty and writer_x_is_empty(writer, current_w_name):
        raise Exception('请先输入小说简介!')
    else:
        return writer
|
| 90 |
+
|
| 91 |
+
def init_chapters_w(writer, check_empty=True):
    """Seed the chapters stage from the finished outline text and switch to it.

    The concatenated outline output (all y halves) becomes the single source
    (x) entry of chapters_w.  Raises when the outline is empty and
    check_empty is True.
    """
    outline_text = "".join(y for _, y in writer['outline_w']['xy_pairs'])
    writer['chapters_w']['xy_pairs'] = [(outline_text, '')]

    writer['current_w'] = 'chapters_w'

    if check_empty and writer_x_is_empty(writer, 'chapters_w'):
        raise Exception('大纲不能为空')
    return writer
|
| 104 |
+
|
| 105 |
+
def init_draft_w(writer, check_empty=True):
    """Seed the draft stage from the finished chapters text and switch to it.

    The concatenated chapters output (all y halves) becomes the single
    source (x) entry of draft_w.  Raises when the chapters text is empty and
    check_empty is True.
    """
    chapters_text = "".join(y for _, y in writer['chapters_w']['xy_pairs'])
    writer['draft_w']['xy_pairs'] = [(chapters_text, '')]

    writer['current_w'] = 'draft_w'

    if check_empty and writer_x_is_empty(writer, 'draft_w'):
        raise Exception('剧情不能为空')
    return writer
|
| 118 |
+
|
| 119 |
+
# 在将writer传递到backend之前,只传递backend需要的部分
|
| 120 |
+
# 这样从backend返回new_writer后,可以直接用update更新writer_state
|
| 121 |
+
def process_writer_to_backend(writer):
    """Extract the backend-relevant subset of *writer* as an independent deep copy.

    Only the stage dicts, the current stage name and the quote span are sent;
    UI-only flags stay in the frontend state.
    """
    backend_keys = ('current_w', 'outline_w', 'chapters_w', 'draft_w', 'quote_span')
    return copy.deepcopy({key: writer[key] for key in backend_keys})
|
| 125 |
+
|
| 126 |
+
# 在整个writer_state生命周期中,其对象地址都不应被改变,这样方便各种flag的检查
|
| 127 |
+
def process_writer_from_backend(writer, new_writer):
    """Merge the backend's stage dicts back into *writer* in place.

    Copies only the three stage dicts (deep-copied so the backend result
    can't be mutated through the frontend state) and returns the SAME
    *writer* object — its identity must never change (flags are checked by
    reference elsewhere).
    """
    for stage in ('outline_w', 'chapters_w', 'draft_w'):
        writer[stage] = copy.deepcopy(new_writer[stage])
    return writer
|
| 131 |
+
|
| 132 |
+
def is_running(writer):
    """Return True while an operation is in flight and not being cancelled."""
    running = writer['running_flag']
    cancelled = writer['cancel_flag']
    return running and not cancelled
|
| 135 |
+
|
| 136 |
+
def has_accept(writer):
    """Return True when the active stage holds pending (unaccepted) edits."""
    active_stage = writer[writer['current_w']]
    return len(active_stage['apply_chunks']) > 0
|
| 140 |
+
|
| 141 |
+
def cancellable(func):
    """Decorator for generator-based UI operations sharing one *writer* dict.

    Guards against concurrent runs and pending unaccepted text, drives the
    wrapped generator while honouring writer['cancel_flag'], and always
    clears running/pause flags on exit.  The wrapped *func* must take
    *writer* as its first argument; pause handling is implemented inside
    *func* itself so it can run code around the pause.
    """
    @functools.wraps(func)
    def wrapper(writer, *args, **kwargs):
        if is_running(writer):
            gr.Warning('另一个操作正在进行中,请等待其完成或取消!')
            return

        # Only the accept handler may run while edits await acceptance.
        if has_accept(writer) and wrapper.__name__ != "on_accept_write":
            gr.Warning('有正在等待接受的文本,点击接受或取消!')
            return

        writer['running_flag'] = True
        writer['cancel_flag'] = False
        writer['pause_flag'] = False

        generator = func(writer, *args, **kwargs)
        result = None
        try:
            while True:
                # Checked between yields — cancellation takes effect at the
                # next item the wrapped generator produces.
                if writer['cancel_flag']:
                    gr.Info('操作已取消!')
                    return

                # Pause logic is implemented inside func so it can run code
                # before/after the pause.
                try:
                    result = next(generator)
                    # If the yielded tuple carries the writer dict, assert that
                    # its identity never changed (flags are shared by reference).
                    if isinstance(result, tuple) and (writer_dict := next((item for item in result if isinstance(item, dict) and 'running_flag' in item), None)):
                        assert writer is writer_dict, 'writer对象地址发生了改变'
                        writer = writer_dict
                    yield result
                except StopIteration as e:
                    return e.value
        except Exception as e:
            raise gr.Error(f'操作过程中发生错误:{e}')
        finally:
            # Always reset, including on cancel/error, so the UI is not stuck.
            writer['running_flag'] = False
            writer['pause_flag'] = False

    return wrapper
|
| 180 |
+
|
| 181 |
+
def try_cancel(writer):
    """Cancel the current activity: pending edits first, then the running op.

    If text is awaiting acceptance (and nothing is running) the pending
    apply_chunks are discarded.  Otherwise sets cancel_flag and waits up to
    3 s for the running generator to notice; warns if it does not.
    """
    if not (is_running(writer) or has_accept(writer)):
        gr.Info('当前没有正在进行的操作或待接受的文本')
        return

    writer['prompt_outputs'] = []
    current_w = writer[writer['current_w']]
    # Prefer discarding unaccepted text over cancelling a running operation.
    if not is_running(writer) and has_accept(writer):
        current_w['apply_chunks'].clear()
        gr.Info('已取消待接受的文本')
        return

    writer['cancel_flag'] = True

    # Poll for up to 3 seconds while the cancellable wrapper winds down.
    start_time = time.time()
    while writer['running_flag'] and time.time() - start_time < 3:
        time.sleep(0.1)

    if writer['running_flag']:
        gr.Warning('取消操作超时,可能需要刷新页面')

    writer['cancel_flag'] = False
|
| 203 |
+
|
| 204 |
+
def writer_y_is_empty(writer, w_name):
    """Return True when every y (generated) side of w_name's xy_pairs is empty."""
    return not any(pair[1] for pair in writer[w_name]['xy_pairs'])
|
| 207 |
+
|
| 208 |
+
def writer_x_is_empty(writer, w_name):
    """Return True when every x (source) side of w_name's xy_pairs is empty."""
    return not any(pair[0] for pair in writer[w_name]['xy_pairs'])
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
# create a markdown table
|
| 214 |
+
# TODO: 优化显示逻辑,字少的列宽度小,字多的列宽度大
|
| 215 |
+
def create_comparison_table(pairs, column_names=None):
    """Render (x, y[, y2]) text tuples as a Markdown comparison table.

    Args:
        pairs: sequence of 2- or 3-tuples of strings. A third column is
            emitted when any tuple has three elements.
        column_names: header labels; defaults to
            ['Original Text', 'Enhanced Text', 'Enhanced Text 2'].

    Returns:
        The Markdown table as a single string.
    """
    # Fixed: the original used a mutable list as the default argument.
    if column_names is None:
        column_names = ['Original Text', 'Enhanced Text', 'Enhanced Text 2']

    def _cell(text):
        # Escape pipes and keep line breaks inside a single table cell.
        return text.replace('|', '\\|').replace('\n', '<br>')

    # Check if any pair has 3 elements
    has_third_column = any(len(pair) == 3 for pair in pairs)

    # Create table header
    if has_third_column:
        table = f"| {column_names[0]} | {column_names[1]} | {column_names[2]} |\n|---------------|-----------------|----------------|\n"
    else:
        table = f"| {column_names[0]} | {column_names[1]} |\n|---------------|---------------|\n"

    # Add rows to the table
    for pair in pairs:
        x = _cell(pair[0])
        y1 = _cell(pair[1])

        if has_third_column:
            y2 = _cell(pair[2]) if len(pair) == 3 else ''
            table += f"| {x} | {y1} | {y2} |\n"
        else:
            table += f"| {x} | {y1} |\n"

    return table
|
| 237 |
+
|
| 238 |
+
def messages2chatbot(messages):
    """Adapt an LLM message list for a chatbot view.

    A leading 'system' message is re-labelled as a 'user' message (chat
    widgets typically cannot display the system role); any other list is
    returned unchanged.
    """
    if not messages or messages[0]['role'] != 'system':
        return messages
    head = {'role': 'user', 'content': messages[0]['content']}
    return [head] + messages[1:]
|
| 243 |
+
|
| 244 |
+
def create_progress_md(writer):
    """Render writer['progress'] as a Markdown checklist of titles/subtitles.

    Finished steps get a check mark, the currently running step gets an
    animated ellipsis (driven by wall-clock time), future steps are plain.
    Returns a gr.Markdown component (empty string when no progress exists).
    """
    progress_md = ""
    if 'progress' in writer and writer['progress']:
        progress = writer['progress']
        progress_md = ""

        # Deduplicate titles/subtitles while preserving first-seen order.
        titles = []
        subtitles = {}
        # (title index, subtitle index) of the current op; inf means "none running".
        current_op_ij = (float('inf'), float('inf'))
        for opi, op in enumerate(progress['ops']):
            if op['title'] not in titles:
                titles.append(op['title'])
            if op['title'] not in subtitles:
                subtitles[op['title']] = []
            if op['subtitle'] not in subtitles[op['title']]:
                subtitles[op['title']].append(op['subtitle'])

            if opi == progress['cur_op_i']:
                # 1-based position of the current op in the deduplicated lists.
                current_op_ij = (len(titles), len(subtitles[op['title']]))

        for i, title in enumerate(titles, 1):
            # Chinese ordinal numbering for section headers (supports up to 10 sections).
            progress_md += f"## {['一', '二', '三', '四', '五', '六', '七', '八', '九', '十'][i-1]}、{title}\n"
            for j, subtitle in enumerate(subtitles[title], 1):
                if i < current_op_ij[0] or (i == current_op_ij[0] and j < current_op_ij[1]):
                    # Step already completed.
                    progress_md += f"### {j}、{subtitle} ✓\n"
                elif i == current_op_ij[0] and j == current_op_ij[1]:
                    # Current step: 0-3 trailing dots cycling with wall-clock seconds.
                    progress_md += f"### {j}、{subtitle} {'.' * (int(time.time()) % 4)}\n"
                else:
                    progress_md += f"### {j}、{subtitle}\n"

            progress_md += "\n"

        progress_md += "---\n"
        # TODO: consider showing only the current progress step.

    return gr.Markdown(progress_md)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def create_text_md(writer):
    """Render the active sub-writer's text as a Markdown comparison table.

    When apply_chunks is non-empty, pending (un-accepted) text is overlaid as
    a third "修正稿" column; if no existing y text exists anywhere, the pending
    text replaces the y column and the header is marked "(待接受)". Returns a
    gr.Markdown whose height depends on content length.
    """
    current_w_name = writer['current_w']
    current_w = writer[current_w_name]
    apply_chunks = current_w['apply_chunks']

    # Column headers depend on which writer stage is active.
    match current_w_name:
        case 'draft_w':
            column_names = ['剧情', '正文', '修正稿']
        case 'outline_w':
            column_names = ['小说简介', '大纲', '修正稿']
        case 'chapters_w':
            column_names = ['大纲', '剧情', '修正稿']
        case _:
            raise Exception('当前状态不正确')

    xy_pairs = current_w['xy_pairs']
    if apply_chunks:
        # Start from all rows, then overlay each pending chunk onto its span.
        table = [[*e, ''] for e in xy_pairs]
        occupied_rows = [False] * len(table)
        for chunk, key, text in apply_chunks:
            if not isinstance(chunk, Chunk):
                # apply_chunks may hold plain dicts after (de)serialization.
                chunk = Chunk(**chunk)
            assert key == 'y_chunk'
            pair_span = chunk.text_source_slice
            if any(occupied_rows[pair_span]):
                raise Exception('apply_chunks中存在重叠的pair_span')
            occupied_rows[pair_span] = [True] * (pair_span.stop - pair_span.start)
            # Collapse the span into one merged row; pad with None placeholders
            # so later spans' indices stay valid, then drop the placeholders.
            table[pair_span] = [[chunk.x_chunk, chunk.y_chunk, text], ] + [None] * (pair_span.stop - pair_span.start - 1)
        table = [e for e in table if e is not None]
        if not any(e[1] for e in table):
            # No existing y text anywhere: show the pending text as the y column.
            column_names = column_names[:2]
            column_names[1] = column_names[1] + '(待接受)'
            table = [[e[0], e[2]] for e in table]
        md = create_comparison_table(table, column_names=column_names)
    else:
        if writer_x_is_empty(writer, current_w_name):
            # First-run hints shown before the user has picked an idea.
            tip_x = '从下方示例中选择一个创意用于创作小说。'
            tip_y = '选择创意后,点击创作大纲。更详细的操作请参考使用指南。'
            if not xy_pairs[0][0].strip():
                xy_pairs = [[tip_x, tip_y]]
            else:
                xy_pairs = [[xy_pairs[0][0], tip_y]]

        md = create_comparison_table(xy_pairs, column_names=column_names[:2])

    # Short tables get a compact panel; longer ones get more room.
    if len(md) < 400:
        height = '200px'
    else:
        height = '600px'
    return gr.Markdown(md, height=height)
|
| 333 |
+
|
core/outline_writer.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from core.parser_utils import parse_chapters
|
| 2 |
+
from core.writer_utils import KeyPointMsg
|
| 3 |
+
from core.writer import Writer
|
| 4 |
+
|
| 5 |
+
from prompts.创作章节.prompt import main as prompt_outline
|
| 6 |
+
from prompts.提炼.prompt import main as prompt_summary
|
| 7 |
+
|
| 8 |
+
class OutlineWriter(Writer):
    """Writer that creates a chapter-level outline (y) from a novel synopsis.

    Each xy pair holds an empty x and one chapter ("第N章 title\\ncontent") in y.
    """

    def __init__(self, xy_pairs, global_context, model=None, sub_model=None, x_chunk_length=2_000, y_chunk_length=2_000, max_thread_num=5):
        super().__init__(xy_pairs, global_context, model, sub_model, x_chunk_length=x_chunk_length, y_chunk_length=y_chunk_length, max_thread_num=max_thread_num)

    def write(self, user_prompt, pair_span=None):
        """Generate/refine outline text over pair_span (generator).

        Requires global_context['summary'] (the novel synopsis). Raises when
        the selected span is empty while other outline text already exists.
        """
        target_chunk = self.get_chunk(pair_span=pair_span)

        if not self.global_context.get("summary", ''):
            raise Exception("需要提供小说简介。")

        if not target_chunk.y_chunk.strip():
            if not self.y.strip():
                # Nothing written yet anywhere: start from the single empty chunk.
                chunks = [target_chunk, ]
            else:
                raise Exception("选中进行创作的内容不能为空,考虑随便填写一些占位的字。")
        else:
            chunks = self.get_chunks(pair_span)

        new_chunks = yield from self.batch_yield(
            [self.write_text(e, prompt_outline, user_prompt) for e in chunks],
            chunks, prompt_name='创作文本')

        # Re-split the generated text into one pair per chapter.
        results = yield from self.batch_split_chapters(new_chunks)

        new_chunks2 = [e[0] for e in results]

        self.apply_chunks(chunks, new_chunks2)

    def split_chapters(self, chunk):
        """Split a chunk's y text into per-chapter xy pairs.

        Returns (edited_chunk, True, '') in the batch_yield result shape.
        """
        if False: yield  # Turn this function into a generator without yielding.

        assert chunk.x_chunk == '', 'chunk.x_chunk不为空'
        chapter_titles, chapter_contents = parse_chapters(chunk.y_chunk)
        new_xy_pairs = self.construct_xy_pairs(chapter_titles, chapter_contents)

        return chunk.edit(text_pairs=new_xy_pairs), True, ''

    def construct_xy_pairs(self, chapter_titles, chapter_contents):
        """Build ('', '第N章 title\\ncontent') pairs from parsed chapters."""
        return [('', f"{title[0]} {title[1]}\n{content}") for title, content in zip(chapter_titles, chapter_contents)]

    def batch_split_chapters(self, chunks):
        """Run split_chapters over all chunks through batch_yield (generator)."""
        results = yield from self.batch_yield(
            [self.split_chapters(e) for e in chunks], chunks, prompt_name='划分章节')
        return results

    def summary(self):
        """Condense the whole outline into global_context['outline'] (generator)."""
        target_chunk = self.get_chunk(pair_span=(0, len(self.xy_pairs)))
        if not target_chunk.y_chunk:
            raise Exception("没有章节需要总结。")
        if len(target_chunk.y_chunk) <= 5:
            raise Exception("需要总结的章节不能少于5个字。")

        # Keep the prompt bounded: sample long outlines down to ~2000 chars.
        if len(target_chunk.y_chunk) > 2000:
            y = self._truncate_chunk(target_chunk.y_chunk)
        else:
            y = target_chunk.y_chunk

        result = yield from prompt_summary(self.model, "提炼大纲", y=y)

        self.global_context['outline'] = result['text']

    def get_model(self):
        """Return the primary model config."""
        return self.model

    def get_sub_model(self):
        """Return the secondary model config."""
        return self.sub_model

    def _truncate_chunk(self, text, chunk_size=100, keep_chunks=20):
        """Truncate chunk content by keeping evenly spaced sections."""
        if len(text) <= 2000:
            return text

        # Split into chunks of chunk_size
        chunks = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]

        # Select evenly spaced chunks.
        # BUGFIX: with fewer than keep_chunks sections the old integer division
        # produced step == 0 and chunks[::0] raised ValueError.
        step = max(1, len(chunks) // keep_chunks)
        selected_chunks = chunks[::step][:keep_chunks]
        new_content = '...'.join(selected_chunks)
        return new_content
|
| 88 |
+
|
core/parser_utils.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def parse_chapters(content):
    """Split novel text into chapters.

    Recognises headings of the form 第N章 where N may use Chinese or Arabic
    numerals with optional dots/dashes (e.g. 第1-1章, 第2.1章).

    Returns:
        ([(chapter_marker, chapter_name), ...], [chapter_body, ...]);
        both lists are empty when no heading is found.
    """
    heading = r'第[零一二三四五六七八九十百千万亿0123456789.-]+章'
    # Capture the heading, the rest of its line, and everything up to the
    # next heading (or end of text) via a lookahead.
    pattern = rf'({heading})([^\n]*)\n*([\s\S]*?)(?={heading}|$)'

    titles = []
    bodies = []
    for marker, name, body in re.findall(pattern, content):
        titles.append((marker, name.strip()))
        bodies.append(body.strip())

    return titles, bodies
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Ad-hoc manual check: run this module directly to eyeball parse_chapters on
# headings mixing dashed (第1-1章), dotted (第2.1章) and plain (第3章) numbering.
if __name__ == "__main__":
    test = """
第1-1章 出世
主角张小凡出身贫寒,因天赋异禀被青云门收为弟子,开始修仙之路。

第2.1章 初入青云

张小凡在青云门中结识师兄弟,学习基础法术,逐渐适应修仙生活。

第3章 灵气初现
张小凡在一次意外中感受到天地灵气,修为有所提升。
"""

    results = parse_chapters(test)  # inspect `results` under a debugger
    print()
|
core/plot_writer.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from core.writer_utils import KeyPointMsg
|
| 2 |
+
from core.writer import Writer
|
| 3 |
+
|
| 4 |
+
from prompts.创作剧情.prompt import main as prompt_plot
|
| 5 |
+
from prompts.提炼.prompt import main as prompt_summary
|
| 6 |
+
|
| 7 |
+
class PlotWriter(Writer):
    """Writer that expands chapter content (x) into detailed plot text (y)."""

    def __init__(self, xy_pairs, global_context, model=None, sub_model=None, x_chunk_length=200, y_chunk_length=1000, max_thread_num=5):
        super().__init__(xy_pairs, global_context, model, sub_model,
                         x_chunk_length=x_chunk_length,
                         y_chunk_length=y_chunk_length,
                         max_thread_num=max_thread_num)

    def write(self, user_prompt, pair_span=None):
        """Generate plot text for the selected pair span (generator).

        Requires global_context['chapter']. Raises when the selection is
        empty while other plot text already exists.
        """
        selected = self.get_chunk(pair_span=pair_span)

        if not self.global_context.get("chapter", ''):
            raise Exception("需要提供章节内容。")

        if selected.y_chunk.strip():
            work_chunks = self.get_chunks(pair_span)
        elif not self.y.strip():
            # Nothing written anywhere yet: create from the single chunk.
            work_chunks = [selected]
        else:
            raise Exception("选中进行创作的内容不能为空,考虑随便填写一些占位的字。")

        written = yield from self.batch_yield(
            [self.write_text(c, prompt_plot, user_prompt) for c in work_chunks],
            work_chunks, prompt_name='创作文本')

        mapped = yield from self.batch_map_text(written)
        final_chunks = [pair[0] for pair in mapped]

        self.apply_chunks(work_chunks, final_chunks)

    def summary(self):
        """Condense the full plot into global_context['chapter'] (generator)."""
        whole = self.get_chunk(pair_span=(0, len(self.xy_pairs)))
        if not whole.y_chunk:
            raise Exception("没有剧情需要总结。")
        if len(whole.y_chunk) <= 5:
            raise Exception("需要总结的剧情不能少于5个字。")

        result = yield from prompt_summary(self.model, "提炼章节", y=whole.y_chunk)

        self.global_context['chapter'] = result['text']

    def get_model(self):
        """Return the primary model config."""
        return self.model

    def get_sub_model(self):
        """Return the secondary model config."""
        return self.sub_model
|
core/summary_novel.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from core.draft_writer import DraftWriter
|
| 3 |
+
from core.plot_writer import PlotWriter
|
| 4 |
+
from core.outline_writer import OutlineWriter
|
| 5 |
+
from core.writer_utils import KeyPointMsg
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def summary_draft(model, sub_model, chapter_title, chapter_text):
    """Summarise one chapter's draft text into plot points (generator).

    Yields progress dicts (message, char count, cost, model) while the
    underlying DraftWriter.summary generator runs; returns the DraftWriter.
    """
    pairs = [('', chapter_text)]

    dw = DraftWriter(pairs, {}, model=model, sub_model=sub_model,
                     x_chunk_length=500, y_chunk_length=1000)
    # Each chapter is processed on a single thread.
    dw.max_thread_num = 1

    stage_title = ''
    for msg in dw.summary(pair_span=(0, len(pairs))):
        if isinstance(msg, KeyPointMsg):
            # Key-point markers only update the stage label used in progress.
            # (Supporting key-point saving would require yielding the writer here.)
            stage_title = msg.prompt_name
            continue

        chunk_list = msg
        cost_so_far = 0
        currency = ''
        done = 0
        char_count = 0
        used_model = None
        for entry in chunk_list:
            if entry is None:
                continue
            done += 1
            output, _chunk = entry
            if output is None:
                # map_text stage: stopped on the first next() with no output.
                continue
            resp = output['response_msgs']
            cost_so_far += resp.cost
            currency = resp.currency_symbol
            char_count += len(resp.response)
            used_model = resp.model

        yield dict(
            progress_msg=f"[{chapter_title}] 提炼章节剧情 {stage_title} 进度:{done}/{len(chunk_list)} 已创作字符:{char_count} 已花费:{cost_so_far:.4f}{currency}",
            chars_num=char_count,
            current_cost=cost_so_far,
            currency_symbol=currency,
            model=used_model
        )

    return dw
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def summary_plot(model, sub_model, chapter_title, chapter_plot):
    """Condense one chapter's plot into an outline entry (generator).

    Yields progress dicts while PlotWriter.summary runs; returns the PlotWriter.
    """
    pw = PlotWriter([('', chapter_plot)], {}, model=model, sub_model=sub_model,
                    x_chunk_length=500, y_chunk_length=1000)

    for output in pw.summary():
        resp = output['response_msgs']
        cost = resp.cost
        currency = resp.currency_symbol
        char_count = len(resp.response)
        yield dict(
            progress_msg=f"[{chapter_title}] 提炼章节大纲 已创作字符:{char_count} 已花费:{cost:.4f}{currency}",
            chars_num=char_count,
            current_cost=cost,
            currency_symbol=currency,
            model=resp.model
        )

    return pw
|
| 72 |
+
|
| 73 |
+
def summary_chapters(model, sub_model, title, chapter_titles, chapter_content):
    """Condense all chapters into a book-level outline (generator).

    Yields progress dicts while OutlineWriter.summary runs; returns the
    OutlineWriter.
    """
    ow = OutlineWriter([('', '')], {}, model=model, sub_model=sub_model,
                       x_chunk_length=500, y_chunk_length=1000)
    # Replace the placeholder pair with the real chapter pairs.
    ow.xy_pairs = ow.construct_xy_pairs(chapter_titles, chapter_content)

    for output in ow.summary():
        resp = output['response_msgs']
        cost = resp.cost
        currency = resp.currency_symbol
        char_count = len(resp.response)
        yield dict(
            progress_msg=f"[{title}] 提炼全书大纲 已创作字符:{char_count} 已花费:{cost:.4f}{currency}",
            chars_num=char_count,
            current_cost=cost,
            currency_symbol=currency,
            model=resp.model
        )

    return ow
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
|
core/writer.py
ADDED
|
@@ -0,0 +1,533 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import numpy as np
|
| 3 |
+
import bisect
|
| 4 |
+
from dataclasses import asdict, dataclass
|
| 5 |
+
|
| 6 |
+
from llm_api import ModelConfig
|
| 7 |
+
from prompts.对齐剧情和正文 import prompt as match_plot_and_text
|
| 8 |
+
from prompts.审阅.prompt import main as prompt_review
|
| 9 |
+
from core.writer_utils import split_text_into_chunks, detect_max_edit_span, run_yield_func
|
| 10 |
+
from core.writer_utils import KeyPointMsg
|
| 11 |
+
from core.diff_utils import get_chunk_changes
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Chunk(dict):
    """A window over a Writer's (x, y) pair list, stored as a plain dict.

    Keys:
        'chunk_pairs':  the (x, y) pairs of the window, context included
        'source_slice': (start, stop) pair indices into the owning Writer
        'text_slice':   slice *within* chunk_pairs selecting the editable
                        part; its stop must be None or negative so the slice
                        stays anchored to the window's tail
    """

    def __init__(self, chunk_pairs: tuple[tuple[str, str, str]], source_slice: tuple[int, int], text_slice: tuple[int, int]):
        super().__init__()
        self['chunk_pairs'] = tuple(chunk_pairs)
        self['source_slice'] = self._as_bounds(source_slice)

        text_bounds = self._as_bounds(text_slice)
        assert text_bounds[1] is None or text_bounds[1] < 0, 'text_slice end must be None or negative'
        self['text_slice'] = text_bounds

    @staticmethod
    def _as_bounds(value):
        # Accept either a slice object or a (start, stop) pair.
        if isinstance(value, slice):
            return (value.start, value.stop)
        return value

    def edit(self, x_chunk=None, y_chunk=None, text_pairs=None):
        """Return a new Chunk with the editable pairs replaced.

        Give exactly one of x_chunk / y_chunk / text_pairs; a single-sided
        edit merges the editable window into one pair with the other side
        kept as-is.
        """
        if x_chunk is not None:
            replacement = [(x_chunk, self.y_chunk)]
        elif y_chunk is not None:
            replacement = [(self.x_chunk, y_chunk)]
        else:
            replacement = text_pairs

        pairs = list(self['chunk_pairs'])
        pairs[self.text_slice] = list(replacement)

        return Chunk(chunk_pairs=tuple(pairs), source_slice=self.source_slice, text_slice=self.text_slice)

    @property
    def source_slice(self) -> slice:
        return slice(*self['source_slice'])

    @property
    def chunk_pairs(self) -> tuple[tuple[str, str]]:
        return self['chunk_pairs']

    @property
    def text_slice(self) -> slice:
        return slice(*self['text_slice'])

    @property
    def text_source_slice(self) -> slice:
        """Pair indices of the editable window inside the source Writer."""
        begin = self.source_slice.start + self.text_slice.start
        end = self.source_slice.stop + (self.text_slice.stop or 0)
        return slice(begin, end)

    @property
    def text_pairs(self) -> tuple[tuple[str, str]]:
        return self.chunk_pairs[self.text_slice]

    @property
    def x_chunk(self) -> str:
        return ''.join(p[0] for p in self.text_pairs)

    @property
    def y_chunk(self) -> str:
        return ''.join(p[1] for p in self.text_pairs)

    @property
    def x_chunk_len(self) -> int:
        return sum(len(p[0]) for p in self.text_pairs)

    @property
    def y_chunk_len(self) -> int:
        return sum(len(p[1]) for p in self.text_pairs)

    @property
    def x_chunk_context(self) -> str:
        return ''.join(p[0] for p in self.chunk_pairs)

    @property
    def y_chunk_context(self) -> str:
        return ''.join(p[1] for p in self.chunk_pairs)

    @property
    def x_chunk_context_len(self) -> int:
        return sum(len(p[0]) for p in self.chunk_pairs)

    @property
    def y_chunk_context_len(self) -> int:
        return sum(len(p[1]) for p in self.chunk_pairs)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class Writer:
|
| 97 |
+
    def __init__(self, xy_pairs, global_context=None, model:ModelConfig=None, sub_model:ModelConfig=None, x_chunk_length=1000, y_chunk_length=1000, max_thread_num=5):
        """Base writer mapping source text x onto generated text y.

        Args:
            xy_pairs: list of (x, y) string pairs forming the parallel document.
            global_context: shared dict of cross-writer state (summary, outline, ...).
            model: primary LLM configuration.
            sub_model: secondary/cheaper LLM configuration.
            x_chunk_length / y_chunk_length: chunking granularity (see note below).
            max_thread_num: per-instance thread cap.
        """
        self.xy_pairs = xy_pairs
        self.global_context = global_context or {}

        self.model = model
        self.sub_model = sub_model

        self.x_chunk_length = x_chunk_length
        self.y_chunk_length = y_chunk_length

        # x_chunk_length is the amount of x fed into one prompt call (controlled by
        # the batch_map functions); it determines the x->y expansion ratio
        # (LLM output window length / x_chunk_length).
        # It also drives the mapped pair size (update_map uses x_chunk_length//2).
        # y_chunk_length has less influence on pair size (the mapping is one-to-many).

        self.max_thread_num = max_thread_num # Per-instance thread cap; useful when several Writer objects run concurrently.
|
| 112 |
+
|
| 113 |
+
    @property
    def x(self): # TODO: consider caching if x turns out to be accessed frequently
        """Concatenated source (x) text of all pairs."""
        return ''.join(pair[0] for pair in self.xy_pairs)

    @property
    def y(self):
        """Concatenated generated (y) text of all pairs."""
        return ''.join(pair[1] for pair in self.xy_pairs)

    @property
    def x_len(self):
        """Total character count of the x side."""
        return sum(len(pair[0]) for pair in self.xy_pairs)

    @property
    def y_len(self):
        """Total character count of the y side."""
        return sum(len(pair[1]) for pair in self.xy_pairs)

    def get_model(self):
        """Return the primary model config."""
        return self.model

    def get_sub_model(self):
        """Return the secondary model config."""
        return self.sub_model
|
| 134 |
+
|
| 135 |
+
def count_span_length(self, span):
|
| 136 |
+
pairs = self.xy_pairs[span[0]:span[1]]
|
| 137 |
+
return sum(len(pair[0]) for pair in pairs), sum(len(pair[1]) for pair in pairs)
|
| 138 |
+
|
| 139 |
+
    def align_span(self, x_span=None, y_span=None):
        """Snap a character span on x (or y) outward to pair boundaries.

        Exactly one of x_span / y_span must be given as (start, end) character
        offsets with start < end. Returns (aligned_span, pair_span): the
        enclosing character span that starts/ends on pair boundaries, and the
        matching (start_pair, end_pair) indices into self.xy_pairs.
        """
        if x_span is None and y_span is None:
            raise ValueError("Either x_span or y_span must be provided")

        if x_span is not None and y_span is not None:
            raise ValueError("Only one of x_span or y_span should be provided")

        is_x = x_span is not None
        z_span = x_span if is_x else y_span
        # Prefix sums of pair lengths on the chosen side: cumsum_z[i] is the
        # character offset at which pair i begins.
        cumsum_z = np.cumsum([0] + [len(pair[0 if is_x else 1]) for pair in self.xy_pairs]).tolist()

        l, r = z_span
        # Widen to the nearest pair boundaries containing [l, r).
        start_chunk = bisect.bisect_right(cumsum_z, l) - 1
        end_chunk = bisect.bisect_left(cumsum_z, r)

        aligned_l = cumsum_z[start_chunk]
        aligned_r = cumsum_z[end_chunk]

        aligned_span = (aligned_l, aligned_r)
        pair_span = (start_chunk, end_chunk)

        # Add assertions to verify the correctness of the output
        assert aligned_l <= l < aligned_r, "aligned_span does not properly contain the start of the input span"
        assert aligned_l < r <= aligned_r, "aligned_span does not properly contain the end of the input span"
        assert 0 <= start_chunk < end_chunk <= len(self.xy_pairs), "pair_span is out of bounds"
        assert sum(len(pair[0 if is_x else 1]) for pair in self.xy_pairs[start_chunk:end_chunk]) == aligned_r - aligned_l, "aligned_span and pair_span do not match"

        return aligned_span, pair_span
|
| 167 |
+
|
| 168 |
+
    def get_chunk(self, pair_span=None, x_span=None, y_span=None, context_length=0, smooth=True):
        """Build a Chunk for the selected region plus surrounding context.

        Exactly one of pair_span (pair indices), x_span or y_span (character
        offsets) selects the editable text. context_length widens the window:
        it counts pairs when pair_span is used, characters otherwise.
        smooth must be True for character spans (they are snapped to pair
        boundaries via align_span).
        """
        if sum(x is not None for x in [pair_span, x_span, y_span]) != 1:
            raise ValueError("Exactly one of pair_span, x_span, or y_span must be provided")

        assert pair_span is None or (pair_span[0] >= 0 and pair_span[1] <= len(self.xy_pairs)), "pair_span is out of bounds"

        is_x = x_span is not None
        is_pair = pair_span is not None

        if is_pair:
            # context_length counts pairs here.
            context_pair_span = (
                max(0, pair_span[0] - context_length),
                min(len(self.xy_pairs), pair_span[1] + context_length)
            )
        else:
            assert smooth, "smooth must be True"
            span = x_span if is_x else y_span
            if smooth:
                # Snap the character span to pair boundaries first.
                span, pair_span = self.align_span(x_span=span if is_x else None, y_span=span if not is_x else None)

            # context_length counts characters here; clamp to document bounds.
            context_span = (
                max(0, span[0] - context_length),
                min(self.x_len if is_x else self.y_len, span[1] + context_length)
            )

            context_span, context_pair_span = self.align_span(x_span=context_span if is_x else None, y_span=context_span if not is_x else None)

        chunk_pairs = self.xy_pairs[context_pair_span[0]:context_pair_span[1]]
        source_slice = context_pair_span
        # text_slice is relative to the context window; keep its end None or
        # negative so the slice stays anchored to the window's tail.
        text_slice = (pair_span[0] - context_pair_span[0], pair_span[1] - context_pair_span[1])
        assert text_slice[1] <= 0, "text_slice end must be negative"
        text_slice = (text_slice[0], None if text_slice[1] == 0 else text_slice[1])

        return Chunk(
            chunk_pairs=chunk_pairs,
            source_slice=source_slice,
            text_slice=text_slice
        )
|
| 206 |
+
|
| 207 |
+
    def get_chunk_pair_span(self, chunk: Chunk):
        """Locate chunk's editable text within self.xy_pairs.

        Fast path: trust chunk.text_source_slice when the pairs there still
        reproduce the chunk's text. Otherwise scan for matching 50-character
        prefixes/suffixes (the pair list may have been re-chunked since the
        Chunk was built). Raises AssertionError when no exact match is found.
        """
        pair_start, pair_end = chunk.text_source_slice.start, chunk.text_source_slice.stop
        merged_x_chunk = ''.join(p[0] for p in self.xy_pairs[pair_start:pair_end])
        merged_y_chunk = ''.join(p[1] for p in self.xy_pairs[pair_start:pair_end])
        if merged_x_chunk == chunk.x_chunk and merged_y_chunk == chunk.y_chunk:
            return pair_start, pair_end

        pair_start, pair_end = 0, len(self.xy_pairs)
        x_chunk, y_chunk = chunk.x_chunk, chunk.y_chunk
        # Find the first pair whose text begins the chunk (prefix heuristic).
        for i, (x, y) in enumerate(self.xy_pairs):
            if x_chunk[:50].startswith(x[:50]) and y_chunk[:50].startswith(y[:50]):
                pair_start = i
                break

        # Find the pair whose text ends the chunk (suffix heuristic).
        for i in range(pair_start, len(self.xy_pairs)):
            x, y = self.xy_pairs[i]
            if x_chunk[-50:].endswith(x[-50:]) and y_chunk[-50:].endswith(y[-50:]):
                pair_end = i + 1
                break

        # Verify the pair_span
        merged_x_chunk = ''.join(p[0] for p in self.xy_pairs[pair_start:pair_end])
        merged_y_chunk = ''.join(p[1] for p in self.xy_pairs[pair_start:pair_end])
        assert x_chunk == merged_x_chunk and y_chunk == merged_y_chunk, "Chunk mismatch"

        return (pair_start, pair_end)
|
| 233 |
+
|
| 234 |
+
def apply_chunks(self, chunks: list[Chunk], new_chunks: list[Chunk]):
    """Replace the pair ranges covered by `chunks` with the pairs of `new_chunks`.

    The two lists are parallel: chunks[i] is replaced by new_chunks[i].
    Overlapping source chunks are rejected. Replacement is done back-to-front
    so earlier spans' indices stay valid while later spans are spliced.
    """
    # Track which pair indices are claimed to detect overlapping chunks.
    occupied_pair_span = [False] * len(self.xy_pairs)
    pair_span_list = [self.get_chunk_pair_span(e) for e in chunks]
    for pair_span in pair_span_list:
        assert not any(occupied_pair_span[pair_span[0]:pair_span[1]]), "Chunk overlap"
        occupied_pair_span[pair_span[0]:pair_span[1]] = [True] * (pair_span[1] - pair_span[0])
    # TODO: could also verify that occupied_pair_span is fully covered here.
    new_pairs_list = [e.text_pairs for e in new_chunks]

    # Sort by span start descending so splicing never invalidates pending indices.
    sorted_spans_with_new_pairs = sorted(
        zip(pair_span_list, new_pairs_list),
        key=lambda x: x[0][0],
        reverse=True
    )

    for (start, end), new_pairs in sorted_spans_with_new_pairs:
        self.xy_pairs[start:end] = new_pairs
|
| 251 |
+
|
| 252 |
+
def get_chunks(self, pair_span=None, chunk_length_ratio=1, context_length_ratio=1, offset_ratio=0):
    """Partition the pairs in `pair_span` into a list of Chunk windows.

    Window sizes are derived from self.x/y_chunk_length scaled by
    chunk_length_ratio; context windows are half that, scaled by
    context_length_ratio. A non-zero offset_ratio in (0, 1) makes only the
    FIRST window shorter (an offset), after which normal sizing resumes.
    """
    pair_span = pair_span or (0, len(self.xy_pairs))
    # All "lengths" below are (x_length, y_length) tuples measured in chars.
    chunk_length = self.x_chunk_length * chunk_length_ratio, self.y_chunk_length * chunk_length_ratio
    context_length = self.x_chunk_length//2 * context_length_ratio, self.y_chunk_length//2 * context_length_ratio

    # NOTE: offset_ratio is rebound from a scalar to a (x, y) char-count tuple.
    if 0 < offset_ratio < 1:
        offset_ratio = int(chunk_length[0] * offset_ratio), int(chunk_length[1] * offset_ratio)

    # Generate chunks
    chunks = []
    start = pair_span[0]
    cstart = self.count_span_length((0, start)) # char_start
    max_cend = self.count_span_length((0, pair_span[1])) # char_end
    while start < pair_span[1]:
        if offset_ratio != 0:
            # First window only: advance by the offset, then clear it.
            cend = cstart[0] + offset_ratio[0], cstart[1] + offset_ratio[1]
            offset_ratio = 0
        else:
            # 80/20 rule: aim for 80% of the window, not an optimal partition.
            cend = cstart[0] + int(chunk_length[0] * 0.8), cstart[1] + int(chunk_length[1] * 0.8)
        cend = min(cend[0], max_cend[0]), min(cend[1], max_cend[1])

        # Pick a non-zero-length span to fetch the chunk from.
        x_len, y_len = cend[0] - cstart[0], cend[1] - cstart[1]
        if x_len > 0:
            chunk1 = self.get_chunk(x_span=(cstart[0], cend[0]), context_length=context_length[0])
        if y_len > 0:
            chunk2 = self.get_chunk(y_span=(cstart[1], cend[1]), context_length=context_length[1])

        if x_len > 0 and y_len == 0:
            chunk = chunk1
        elif x_len == 0 and y_len > 0:
            chunk = chunk2
        elif x_len > 0 and y_len > 0:
            # Prefer the chunk whose source_slice covers fewer pairs.
            chunk = chunk1 if chunk1.source_slice.stop - chunk1.source_slice.start < chunk2.source_slice.stop - chunk2.source_slice.start else chunk2
        else:
            raise ValueError("Both x_span and y_span have zero length")

        # assert chunk.x_chunk_context_len <= self.x_chunk_length * 2 and chunk.y_chunk_context_len <= self.y_chunk_length * 2, \
        #     "Could not obtain a short enough chunk; adjust chunk or window length!"

        chunks.append(chunk)
        # Resume after the pairs actually consumed by the selected chunk.
        start = chunk.text_source_slice.stop
        cstart = self.count_span_length((0, start))

    return chunks
|
| 298 |
+
|
| 299 |
+
# TODO: batch_yield 可以考虑输入生成器,而不是函数及参数
|
| 300 |
+
def batch_yield(self, generators, chunks, prompt_name=None):
    """Round-robin drive up to self.max_thread_num generators, yielding progress.

    Yields `yields` — a list of (latest_yield_value, chunk) per generator — on
    each pass, optionally bracketed by a KeyPointMsg (start/finished) when
    prompt_name is given. Returns the list of each generator's return value.
    """
    # TODO: batch_yield could accept generators instead of funcs + args.
    # TODO: later, only yield new_chunks instead of repeating chunks.

    # Process all pairs with the prompt and yield intermediate results.
    results = [None] * len(generators)
    yields = [None] * len(generators)
    finished = [False] * len(generators)
    first_iter_flag = True
    while True:
        co_num = 0  # number of still-active generators touched this pass
        for i, gen in enumerate(generators):
            if finished[i]:
                continue

            try:
                co_num += 1
                yield_value = next(gen)
                # The chunk rides along with each yield for the frontend.
                yields[i] = (yield_value, chunks[i])
            except StopIteration as e:
                results[i] = e.value
                finished[i] = True
                if yields[i] is None: yields[i] = (None, chunks[i])

            # Cap concurrency per pass at max_thread_num.
            if co_num >= self.max_thread_num:
                break

        if all(finished):
            break

        # Emit the "started" key-point message exactly once, lazily.
        if first_iter_flag and prompt_name is not None:
            yield (kp_msg := KeyPointMsg(prompt_name=prompt_name))
            first_iter_flag = False

        yield yields  # progress yields are always lists of tuples

    if not first_iter_flag and prompt_name is not None:
        yield kp_msg.set_finished()

    return results
|
| 339 |
+
|
| 340 |
+
# 临时函数,用于配合前端,返回一个更改,对self施加该更改可以变为cur
|
| 341 |
+
def diff_to(self, cur, pair_span=None):
    """Temporary frontend helper: compute a change set that turns self into `cur`.

    Returns a list of (x_text, old_y_text, new_y_text) triples restricted to
    pair_span. Alignment between self and cur is done by matching cumulative
    x-lengths with two advancing pointers.
    """
    if pair_span is None:
        pair_span = (0, len(self.xy_pairs))

    if self.count_span_length(pair_span)[0] == 0:
        # In v2.1, chapter/plot writing does not reference x, so align y
        # lists positionally and pad the shorter side with empty strings.
        pair_span2 = (0 + pair_span[0], len(cur.xy_pairs) - (len(self.xy_pairs) - pair_span[1]))
        y_list = [e[1] for e in self.xy_pairs[pair_span[0]:pair_span[1]]]
        y2_list =[e[1] for e in cur.xy_pairs[pair_span2[0]:pair_span2[1]]]

        y_list += ['',] * max(len(y2_list) - len(y_list), 0)
        y2_list += ['',] * max(len(y_list) - len(y2_list), 0)

        data_chunks = [('', y, y2) for y, y2 in zip(y_list, y2_list)]

        return data_chunks

    # Pointers are half-open (start, end) pair-index windows on each side.
    pre_pointer = 0, 1
    cur_pointer = 0, 1

    # Prefix sums of x lengths; equal window sums mean the x text aligns.
    cum_sum_pre = np.cumsum([0] + [len(pair[0]) for pair in self.xy_pairs])
    cum_sum_cur = np.cumsum([0] + [len(pair[0]) for pair in cur.xy_pairs])

    apply_chunks = []

    while pre_pointer[1] <= len(self.xy_pairs) and cur_pointer[1] <= len(cur.xy_pairs):
        if cum_sum_pre[pre_pointer[1]] - cum_sum_pre[pre_pointer[0]] == cum_sum_cur[cur_pointer[1]] - cum_sum_cur[cur_pointer[0]]:
            # Windows align on x: record the y replacement and advance both.
            chunk = self.get_chunk(pair_span=pre_pointer)
            value = "".join(pair[1] for pair in cur.xy_pairs[cur_pointer[0]:cur_pointer[1]])
            apply_chunks.append((chunk, 'y_chunk', value))

            pre_pointer = pre_pointer[1], pre_pointer[1] + 1
            cur_pointer = cur_pointer[1], cur_pointer[1] + 1
        elif cum_sum_pre[pre_pointer[1]] - cum_sum_pre[pre_pointer[0]] < cum_sum_cur[cur_pointer[1]] - cum_sum_cur[cur_pointer[0]]:
            # Grow the shorter (self) window.
            pre_pointer = pre_pointer[0], pre_pointer[1] + 1
        else:
            # Grow the shorter (cur) window.
            cur_pointer = cur_pointer[0], cur_pointer[1] + 1

    # Both sides must have been fully consumed for a valid alignment.
    assert pre_pointer[1] == len(self.xy_pairs) + 1 and cur_pointer[1] == len(cur.xy_pairs) + 1

    # Keep only changes that fall entirely inside the requested pair_span.
    filtered_apply_chunks = []
    for e in apply_chunks:
        text_source_slice = e[0].text_source_slice
        if text_source_slice.start >= pair_span[0] and text_source_slice.stop <= pair_span[1]:
            filtered_apply_chunks.append(e)

    data_chunks = []
    for chunk, key, value in filtered_apply_chunks:
        data_chunks.append((chunk.x_chunk, chunk.y_chunk, value))

    return data_chunks
|
| 392 |
+
|
| 393 |
+
# 临时函数,用于配合前端
|
| 394 |
+
def apply_chunk(self, chunk: Chunk, key, value):
    """Temporary frontend helper: apply a single-field edit to one chunk.

    Accepts either a Chunk instance or a plain dict of Chunk kwargs.
    """
    if isinstance(chunk, Chunk):
        source_chunk = chunk
    else:
        source_chunk = Chunk(**chunk)
    edited_chunk = source_chunk.edit(**{key: value})
    self.apply_chunks([source_chunk], [edited_chunk])
|
| 399 |
+
|
| 400 |
+
def write_text(self, chunk:Chunk, prompt_main, user_prompt_text, input_keys=None, model=None):
    """Run a writing prompt on a chunk and return the edited chunk (generator).

    chunk attributes are renamed to the prompt's expected kwarg names, merged
    with self.global_context, and fed to prompt_main; the result text is
    written back into the chunk.
    """
    # Mapping from Chunk attribute names to the prompt's kwarg names.
    chunk2prompt_key = {
        'x_chunk': 'x',
        'y_chunk': 'y',
        'x_chunk_context': 'context_x',
        'y_chunk_context': 'context_y'
    }

    if input_keys is not None:
        # Caller restricted which chunk fields feed the prompt; all must be set.
        prompt_kwargs = {k: getattr(chunk, k) for k in input_keys}
        assert all(prompt_kwargs.values()), "Missing required context keys"
    else:
        prompt_kwargs = {k: getattr(chunk, k) for k in chunk2prompt_key.keys()}

    prompt_kwargs = {chunk2prompt_key.get(k, k): v for k, v in prompt_kwargs.items()}

    # prompt_kwargs carries everything; the prompt decides what it uses.
    prompt_kwargs.update(self.global_context)

    result = yield from prompt_main(
        model=model or self.get_model(),
        user_prompt=user_prompt_text,
        **prompt_kwargs
    )

    # Compatibility shim for summary_prompt in v2.2; the text_key design
    # will be dropped later.
    update_dict = {}
    if 'text_key' in result:
        update_dict[result['text_key']] = result['text']
    else:
        update_dict['y_chunk'] = result['text']

    return chunk.edit(**update_dict)
|
| 433 |
+
|
| 434 |
+
# 目前review(审阅)的评分机制暂未实装
|
| 435 |
+
def review_text(self, chunk:Chunk, prompt_name, model=None):
    """Run a review prompt over the chunk's y text; returns the review text.

    NOTE: review scoring is not implemented yet.
    """
    result = yield from prompt_review(
        model=model or self.get_model(),
        prompt_name=prompt_name,
        y=chunk.y_chunk
    )

    return result['text']
|
| 443 |
+
|
| 444 |
+
def map_text_wo_llm(self, chunk:Chunk):
    """Re-pair a chunk WITHOUT an LLM by splitting over-long one-sided pairs.

    Ensures each pair in chunk.text_pairs fits the configured chunk lengths.
    Pairs where only one side has content are split on that side; pairs with
    both sides populated cannot be split here and raise if over-long.
    """
    new_xy_pairs = []
    for x, y in chunk.text_pairs:
        if x.strip() and not y.strip():
            # Only x has text: split x, duplicate the (empty) y.
            x_pairs = split_text_into_chunks(x, self.x_chunk_length, min_chunk_n=1, min_chunk_size=5)
            new_xy_pairs.extend([(x_pair, y) for x_pair in x_pairs])
        elif not x.strip() and y.strip():
            # Only y has text: split y, duplicate the (empty) x.
            y_pairs = split_text_into_chunks(y, self.y_chunk_length, min_chunk_n=1, min_chunk_size=5)
            new_xy_pairs.extend([(x, y_pair) for y_pair in y_pairs])
        else:
            # Both sides populated: cannot split without an LLM mapping.
            if len(x) > self.x_chunk_length or len(y) > self.y_chunk_length:
                raise ValueError("窗口太小或段落太长!考虑选择更大的窗口长度或手动分段。")
            new_xy_pairs.append((x, y))

    return chunk.edit(text_pairs=new_xy_pairs)
|
| 460 |
+
|
| 461 |
+
def map_text(self, chunk:Chunk):
    """Re-align a chunk's x (plot) and y (prose) into matched pairs (generator).

    Splits both sides into sub-chunks and, when a non-trivial alignment is
    needed, asks the sub-model (match_plot_and_text) to map plot chunks onto
    text chunks. Returns (new_chunk, ok_flag, message).
    """
    # TODO: map should sanity-check that mapped content roughly corresponds
    # and that nothing was wrongly mapped into context.

    if chunk.x_chunk.strip():
        x_pairs = split_text_into_chunks(chunk.x_chunk, self.x_chunk_length, min_chunk_n=1, min_chunk_size=5, max_chunk_n=20)
        assert len(x_pairs) >= len(chunk.text_pairs), "未知错误!合并所有区块后再分区块,结果更少?"
        if len(x_pairs) == len(chunk.text_pairs):
            # Split count unchanged: existing pairing is already fine.
            return chunk, True, ''
    else:
        # No x at all: y was written from global_context, not from x.
        y_pairs = split_text_into_chunks(chunk.y_chunk, self.y_chunk_length, min_chunk_n=1, min_chunk_size=5, max_chunk_n=20)
        new_xy_pairs = [('', y) for y in y_pairs]
        return chunk.edit(text_pairs=new_xy_pairs), True, ''

    try:
        # Prefer at least as many y chunks as x chunks.
        y_pairs = split_text_into_chunks(chunk.y_chunk, self.y_chunk_length, min_chunk_n=len(x_pairs), min_chunk_size=5, max_chunk_n=20)
    except Exception as e:
        # If y cannot be split further, shrink the x split instead.
        y_pairs = split_text_into_chunks(chunk.y_chunk, self.y_chunk_length, min_chunk_n=1, min_chunk_size=5, max_chunk_n=20)
        x_pairs = split_text_into_chunks(chunk.x_chunk, self.x_chunk_length, min_chunk_n=1, min_chunk_size=5, max_chunk_n=int(0.8 * len(y_pairs)))

    # TODO: the current mapping prompt requires |x| < |y|; the prompt will
    # be improved later.

    try:
        gen = match_plot_and_text.main(
            model=self.get_sub_model(),
            plot_chunks=x_pairs,
            text_chunks=y_pairs
        )
        # Forward the sub-generator's progress; its return value arrives
        # via StopIteration below.
        while True:
            yield next(gen)
    except StopIteration as e:
        output = e.value

    # plot2text: list of ([x indices], [y indices]) alignment groups.
    x2y = output['plot2text']
    new_xy_pairs = []
    for xi_list, yi_list in x2y:
        xl, xr = xi_list[0], xi_list[-1]
        new_xy_pairs.append(("".join(x_pairs[xl:xr+1]), "".join(y_pairs[i] for i in yi_list)))

    new_chunk = chunk.edit(text_pairs=new_xy_pairs)
    return new_chunk, True, ''
|
| 503 |
+
|
| 504 |
+
def batch_map_text(self, chunks):
    """Run map_text over every chunk concurrently; returns one result per chunk."""
    mapping_generators = [self.map_text(chunk) for chunk in chunks]
    outcome = yield from self.batch_yield(mapping_generators, chunks, prompt_name='映射文本')
    return outcome
|
| 508 |
+
|
| 509 |
+
def batch_write_apply_text(self, chunks, prompt_main, user_prompt_text):
    """Write text for every chunk, re-map the results, and splice them in."""
    writers = [self.write_text(chunk, prompt_main, user_prompt_text) for chunk in chunks]
    written_chunks = yield from self.batch_yield(writers, chunks, prompt_name='创作文本')

    mapped = yield from self.batch_map_text(written_chunks)
    remapped_chunks = [item[0] for item in mapped]

    self.apply_chunks(chunks, remapped_chunks)
|
| 518 |
+
|
| 519 |
+
def batch_review_write_apply_text(self, chunks, write_prompt_main, review_prompt_name):
    """Review each chunk, rewrite it guided by the review, re-map and apply."""
    reviewers = [self.review_text(chunk, review_prompt_name) for chunk in chunks]
    reviews = yield from self.batch_yield(reviewers, chunks, prompt_name='审阅文本')

    rewrite_instrustion = "\n\n根据审阅意见,重新创作,如果审阅意见表示无需改动,则保持原样输出。"

    rewriters = [
        self.write_text(chunk, write_prompt_main, review + rewrite_instrustion)
        for chunk, review in zip(chunks, reviews)
    ]
    rewritten_chunks = yield from self.batch_yield(rewriters, chunks, prompt_name='创作文本')

    mapped = yield from self.batch_map_text(rewritten_chunks)
    remapped_chunks = [item[0] for item in mapped]

    self.apply_chunks(chunks, remapped_chunks)
|
core/writer_utils.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import uuid
|
| 2 |
+
|
| 3 |
+
# 定义了用于Wirter yield的数据类型,同时也是前端展示的“关键点”消息
|
| 4 |
+
# Data type yielded by Writer; also rendered by the frontend as a "key point" message.
class KeyPointMsg(dict):
    """A dict-based progress message.

    Exactly one of two forms is valid:
      * prompt form: only `prompt_name` is given;
      * title form: both `title` and `subtitle` are given.
    """

    def __init__(self, title='', subtitle='', prompt_name=''):
        super().__init__()
        prompt_form = (not title) and (not subtitle) and bool(prompt_name)
        title_form = bool(title) and bool(subtitle) and (not prompt_name)
        if not (prompt_form or title_form):
            raise ValueError('Either title and subtitle or prompt_name must be provided')

        self.update(
            id=str(uuid.uuid4()),
            title=title,
            subtitle=subtitle,
            prompt_name=prompt_name,
            finished=False,
        )

    def set_finished(self):
        """Mark the message finished; returns self for chaining."""
        assert not self['finished'], 'finished flag is already set'
        self['finished'] = True
        return self

    def is_finished(self):
        return self['finished']

    def is_prompt(self):
        return bool(self.prompt_name)

    def is_title(self):
        return bool(self.title)

    @property
    def id(self):
        return self['id']

    @property
    def title(self):
        return self['title']

    @property
    def subtitle(self):
        return self['subtitle']

    @property
    def prompt_name(self):
        # Display form: names of 10+ chars are shortened with an ellipsis.
        name = self['prompt_name']
        return name[:10] + '...' if len(name) >= 10 else name
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
import re
|
| 57 |
+
from difflib import Differ
|
| 58 |
+
|
| 59 |
+
# 后续考虑采用现成的库实现,目前逻辑过于繁琐,而且太慢了
|
| 60 |
+
# Consider an off-the-shelf library later; this is convoluted and slow.
def detect_max_edit_span(a, b):
    """Return (prefix_len, -suffix_len) of the unchanged region around the edit.

    prefix_len counts leading elements common to both sequences; suffix_len
    counts trailing common elements (returned negated, slice-style).
    """
    prefix = 0
    suffix = 0
    counting_prefix = True

    for entry in Differ().compare(a, b):
        if not entry.startswith(' '):
            # Hit an edit: stop prefix counting and restart the suffix run.
            counting_prefix = False
            suffix = 0
        elif counting_prefix:
            prefix += 1
        else:
            suffix += 1

    return prefix, -suffix
|
| 78 |
+
|
| 79 |
+
def split_text_by_separators(text, separators, keep_separators=True):
    """Split text into paragraphs on the given separator characters.

    Args:
        text: the text to split.
        separators: list of separator strings.
        keep_separators: keep separators attached to their paragraph (default True).
    Returns:
        List of paragraphs; a run of consecutive separators stays with the
        paragraph it terminates.
    """
    # BUGFIX: the '+' must quantify the whole alternation via a non-capturing
    # group, so a run of mixed separators (e.g. "。\n") is captured as ONE
    # separator token. Previously '+' bound only to the last alternative.
    pattern = f'((?:{"|".join(map(re.escape, separators))})+)'
    chunks = re.split(pattern, text)

    paragraphs = []
    current_para = []

    # re.split with a capturing group alternates content and separator chunks.
    for i in range(0, len(chunks), 2):
        content = chunks[i]
        separator = chunks[i + 1] if i + 1 < len(chunks) else ''

        current_para.append(content)
        if keep_separators and separator:
            current_para.append(separator)

        # Only flush when this piece carried real content; otherwise keep
        # accumulating (leading separators attach to the next paragraph).
        if content.strip():
            paragraphs.append(''.join(current_para))
            current_para = []

    return paragraphs
|
| 108 |
+
|
| 109 |
+
def split_text_into_paragraphs(text, keep_separators=True):
    """Split text into paragraphs on newline boundaries."""
    newline_separators = ['\n']
    return split_text_by_separators(text, newline_separators, keep_separators)
|
| 111 |
+
|
| 112 |
+
def split_text_into_sentences(text, keep_separators=True):
    """Split text into sentences on newlines and CJK sentence-ending punctuation."""
    sentence_separators = ['\n', '。', '?', '!', ';']
    return split_text_by_separators(text, sentence_separators, keep_separators)
|
| 114 |
+
|
| 115 |
+
def run_and_echo_yield_func(func, *args, **kwargs):
    """Drive a generator-style `func`, echoing message text incrementally.

    Each yielded value is a list of {'role', 'content'} messages. Only the
    newly appended text is printed when the rendering grows monotonically;
    otherwise a divider is printed and rendering restarts. Returns every
    yielded message list.
    """
    printed_so_far = ""
    collected = []
    for messages in func(*args, **kwargs):
        collected.append(messages)
        rendered = "\n".join(f"{msg['role']}:\n{msg['content']}" for msg in messages)
        if rendered.startswith(printed_so_far):
            delta = rendered[len(printed_so_far):]
        else:
            # Rendering diverged from what was printed: reset with a divider.
            printed_so_far = ""
            print('\n--------------------------------')
            delta = rendered

        print(delta, end="")
        printed_so_far += delta
    return collected
|
| 131 |
+
|
| 132 |
+
def run_yield_func(func, *args, **kwargs):
    """Exhaust the generator produced by `func` and return its StopIteration value."""
    generator = func(*args, **kwargs)
    while True:
        try:
            next(generator)
        except StopIteration as stop:
            return stop.value
|
| 139 |
+
|
| 140 |
+
def split_text_into_chunks(text, max_chunk_size, min_chunk_n, min_chunk_size=1, max_chunk_n=1000):
    """Split text into between min_chunk_n and max_chunk_n chunks.

    Each chunk's length must lie in [min_chunk_size, max_chunk_size]. Works in
    two phases: merge the smallest adjacent paragraphs until counts/sizes fit,
    then split the longest paragraphs at sentence boundaries.

    Raises:
        Exception: when no valid split point exists or the loop cannot converge.
    """
    def split_paragraph(para):
        # Split one paragraph near its midpoint at a sentence boundary.
        mid = len(para) // 2
        split_pattern = r'[。?;]'
        split_points = [m.end() for m in re.finditer(split_pattern, para)]

        if not split_points:
            raise Exception("没有找到分割点!")

        closest_point = min(split_points, key=lambda x: abs(x - mid))
        # Both halves must carry real content.
        if not para[:closest_point].strip() or not para[closest_point:].strip():
            raise Exception("没有找到分割点!")

        return para[:closest_point], para[closest_point:]

    paragraphs = split_text_into_paragraphs(text)

    assert max_chunk_n >= 1, "max_chunk_n必须大于等于1"
    assert sum(len(p) for p in paragraphs) >= min_chunk_size, f"分割时,输入的文本长度小于要���的min_chunk_size:{min_chunk_size}"
    count = 0  # guard against an infinite merge/split loop
    # Phase 1: merge while there are too many chunks or any chunk is too small.
    while len(paragraphs) > max_chunk_n or min(len(p) for p in paragraphs) < min_chunk_size:
        assert (count:=count+1) < 1000, "分割进入死循环!"

        # Find the adjacent pair with the smallest combined length to merge.
        min_sum = float('inf')
        min_i = 0

        for i in range(len(paragraphs) - 1):
            curr_sum = len(paragraphs[i]) + len(paragraphs[i + 1])
            if curr_sum < min_sum:
                min_sum = curr_sum
                min_i = i

        # Merge that pair in place.
        paragraphs[min_i:min_i + 2] = [''.join(paragraphs[min_i:min_i + 2])]

    # Phase 2: split while there are too few chunks or any chunk is too long.
    while len(paragraphs) < min_chunk_n or max(len(p) for p in paragraphs) > max_chunk_size:
        assert (count:=count+1) < 1000, "分割进入死循环!"
        longest_para_i = max(range(len(paragraphs)), key=lambda i: len(paragraphs[i]))
        part1, part2 = split_paragraph(paragraphs[longest_para_i])
        if len(part1) < min_chunk_size or len(part2) < min_chunk_size or len(paragraphs) + 1 > max_chunk_n:
            raise Exception("没有找到合适的分割点!")
        paragraphs[longest_para_i:longest_para_i+1] = [part1, part2]

    return paragraphs
|
| 185 |
+
|
| 186 |
+
def test_split_text_into_chunks():
    """Smoke tests for split_text_into_chunks: count, size, and newline handling."""
    # Case 1: simple paragraph splitting — expect exactly three chunks.
    simple_text = "这是第一段。这是第二段。这是第三段。"
    result1 = split_text_into_chunks(simple_text, max_chunk_size=10, min_chunk_n=3)
    print("Test 1 result:", result1)
    assert len(result1) == 3, f"Expected 3 chunks, got {len(result1)}"

    # Case 2: one long paragraph must be split at punctuation into small chunks.
    long_text = "这是一个很长的段落,包含了很多句子。它应该被分割成多个小块。这里有一些标点符号,比如句号。还有问号?以及分号;这些都可以用来分割文本。"
    result2 = split_text_into_chunks(long_text, max_chunk_size=20, min_chunk_n=4)
    print("Test 2 result:", result2)
    assert len(result2) >= 4, f"Expected at least 4 chunks, got {len(result2)}"
    assert all(len(chunk) <= 20 for chunk in result2), "Some chunks are longer than max_chunk_size"

    # Case 3: text containing newlines, mixing short and long paragraphs.
    newline_text = "第一段。\n\n第二段。\n第三段。\n\n第四段很长,需要被分割。这是第四段的继续。"
    result3 = split_text_into_chunks(newline_text, max_chunk_size=15, min_chunk_n=5)
    print("Test 3 result:", result3)
    assert len(result3) >= 5, f"Expected at least 5 chunks, got {len(result3)}"
    assert all(len(chunk) <= 15 for chunk in result3), "Some chunks are longer than max_chunk_size"

    print("All tests passed!")
|
| 209 |
+
|
| 210 |
+
if __name__ == "__main__":
    # Ad-hoc manual checks for detect_max_edit_span on Chinese strings,
    # followed by the split_text_into_chunks smoke tests.
    print(detect_max_edit_span("我吃西红柿", "我不喜欢吃西红柿"))
    print(detect_max_edit_span("我吃西红柿", "不喜欢吃西红柿"))
    print(detect_max_edit_span("我吃西红柿", "我不喜欢吃"))
    print(detect_max_edit_span("我吃西红柿", "你不喜欢吃西瓜"))

    test_split_text_into_chunks()
|
custom/根据提纲创作正文/天蚕土豆风格.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
你是一个网文大神作家,外号天蚕豌豆,擅长写玄幻网文,代表作有《斗破天空》,《舞动乾坤》,《我主宰》。
|
| 2 |
+
|
| 3 |
+
你的常用反派话语有:
|
| 4 |
+
此子断不可留,否则日后必成大患!
|
| 5 |
+
做事留一线,日后好相见。
|
| 6 |
+
一口鲜血夹杂着破碎的内脏喷出。
|
| 7 |
+
能把我逼到这种地步,你足以自傲了。
|
| 8 |
+
放眼XXX,你也算是凤毛麟角般的存在。
|
| 9 |
+
|
| 10 |
+
你的常用词语有:
|
| 11 |
+
黯然销魂、神出鬼没、格格不入、微不足道、窃窃私语、给我破、给我碎、摧枯拉朽、倒吸一口凉气、一脚踢开、旋即、苦笑、美眸、一拳、放眼、桀桀、负手而立、摧枯拉朽、黑袍老者、摸了摸鼻子、妮子、贝齿紧咬着红唇、幽怨、浊气、凤毛麟角、一声娇喝、恐怖如斯、纤纤玉手、头角峥嵘、桀桀桀、虎躯一震、苦笑一声、三千青丝
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
下面我会给你一段网文提纲,需要你对其进行润色或重写,输出网文正文。
|
custom/根据提纲创作正文/对草稿进行润色.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
你是一个网文作家,下面我会给你一段简陋粗略的网文草稿,需要你对其进行润色或重写,输出网文正文。
|
| 2 |
+
|
| 3 |
+
在创作的过程中,你需要注意以下事项:
|
| 4 |
+
1. 在草稿的基础上进行创作,不要过度延申,不要在结尾进行总结。
|
| 5 |
+
2. 对于草稿中内容,在正文中需要用小说家的口吻去描写,包括语言、行为、人物、环境描写等。
|
| 6 |
+
3. 对于草稿中缺失的部分,在正文中需要进行补全。
|
| 7 |
+
|
healthcheck.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import http.client
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
BACKEND_PORT = int(os.environ.get('BACKEND_PORT', 7869))
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def check_health():
    """Probe the backend /health endpoint on localhost.

    Returns:
        True iff the endpoint answers with HTTP 200; False otherwise
        (non-200 status is printed to stdout, connection errors to stderr).
    """
    conn = None
    try:
        conn = http.client.HTTPConnection("localhost", BACKEND_PORT)
        conn.request("GET", "/health")
        response = conn.getresponse()
        if response.status == 200:
            print("Health check passed")
            return True
        else:
            print(f"Health check failed: {response.status}")
            return False
    except Exception as e:
        print(f"Health check failed: {e}", file=sys.stderr)
        return False
    finally:
        # BUGFIX: previously the connection was never closed, leaking the socket.
        if conn is not None:
            conn.close()
|
| 22 |
+
|
| 23 |
+
if __name__ == "__main__":
    # Exit code 0 on healthy backend, 1 otherwise (Docker HEALTHCHECK convention).
    sys.exit(0 if check_health() else 1)
|
llm_api/__init__.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Any, Optional, Generator
|
| 2 |
+
|
| 3 |
+
from .mongodb_cache import llm_api_cache
|
| 4 |
+
from .baidu_api import stream_chat_with_wenxin, wenxin_model_config
|
| 5 |
+
from .doubao_api import stream_chat_with_doubao, doubao_model_config
|
| 6 |
+
from .chat_messages import ChatMessages
|
| 7 |
+
from .openai_api import stream_chat_with_gpt, gpt_model_config
|
| 8 |
+
from .zhipuai_api import stream_chat_with_zhipuai, zhipuai_model_config
|
| 9 |
+
|
| 10 |
+
class ModelConfig(dict):
    """Dict-based LLM configuration: a model name plus provider options.

    Validates on construction that the provider-specific credential keys
    are present and non-blank, and that max_tokens is set and <= 4096.
    """

    def __init__(self, model: str, **options):
        super().__init__(**options)
        self['model'] = model
        self.validate()

    def validate(self):
        """Check required credential keys for the detected provider; raise ValueError if missing/blank."""
        def check_key(provider, keys):
            for key in keys:
                if key not in self:
                    raise ValueError(f"{provider}的API设置中未传入: {key}")
                elif not self[key].strip():
                    raise ValueError(f"{provider}的API设置中未配置: {key}")

        # Provider is inferred from the model name; unknown names fall
        # through to the OpenAI-compatible branch.
        if self['model'] in wenxin_model_config:
            check_key('文心一言', ['ak', 'sk'])
        elif self['model'] in doubao_model_config:
            check_key('豆包', ['api_key', 'endpoint_id'])
        elif self['model'] in zhipuai_model_config:
            check_key('智谱AI', ['api_key'])
        elif self['model'] in gpt_model_config or True:
            # Any other model name defaults to the OpenAI-style API.
            check_key('OpenAI', ['api_key'])

        if 'max_tokens' not in self:
            raise ValueError('ModelConfig未传入key: max_tokens')
        else:
            assert self['max_tokens'] <= 4_096, 'max_tokens最大为4096!'

    def get_api_keys(self) -> Dict[str, str]:
        """Return every option except the model name (credentials and extras)."""
        return {k: v for k, v in self.items() if k not in ['model']}
|
| 42 |
+
|
| 43 |
+
@llm_api_cache()
def stream_chat(model_config: ModelConfig, messages: list, response_json=False) -> Generator:
    """Dispatch a streaming chat call to the provider matching the model name.

    Generator protocol: first yields the normalized ChatMessages, then relays
    each partial message list from the provider, finally yields (and returns)
    the completed result with `finished = True`.
    """
    if isinstance(model_config, dict):
        model_config = ModelConfig(**model_config)

    model_config.validate()

    messages = ChatMessages(messages, model=model_config['model'])

    assert model_config['max_tokens'] <= 4096, 'max_tokens最大为4096!'

    # Reject prompts that already exceed the token budget before calling out.
    if messages.count_message_tokens() > model_config['max_tokens']:
        raise Exception(f'请求的文本过长,超过最大tokens:{model_config["max_tokens"]}。')

    yield messages

    # Provider dispatch mirrors ModelConfig.validate: unknown models fall
    # through to the OpenAI-compatible client.
    if model_config['model'] in wenxin_model_config:
        result = yield from stream_chat_with_wenxin(
            messages,
            model=model_config['model'],
            ak=model_config['ak'],
            sk=model_config['sk'],
            max_tokens=model_config['max_tokens'],
            response_json=response_json
        )
    elif model_config['model'] in doubao_model_config:  # doubao models
        result = yield from stream_chat_with_doubao(
            messages,
            model=model_config['model'],
            endpoint_id=model_config['endpoint_id'],
            api_key=model_config['api_key'],
            max_tokens=model_config['max_tokens'],
            response_json=response_json
        )
    elif model_config['model'] in zhipuai_model_config:  # zhipuai models
        result = yield from stream_chat_with_zhipuai(
            messages,
            model=model_config['model'],
            api_key=model_config['api_key'],
            max_tokens=model_config['max_tokens'],
            response_json=response_json
        )
    elif model_config['model'] in gpt_model_config or True:  # openai models or any OpenAI-compatible model
        result = yield from stream_chat_with_gpt(
            messages,
            model=model_config['model'],
            api_key=model_config['api_key'],
            base_url=model_config.get('base_url'),
            proxies=model_config.get('proxies'),
            max_tokens=model_config['max_tokens'],
            response_json=response_json
        )

    result.finished = True
    yield result

    return result
|
| 100 |
+
|
| 101 |
+
def test_stream_chat(model_config: ModelConfig):
    """Smoke-test a model config with a trivial prompt, relaying each response.

    NOTE(review): `use_cache` is presumably consumed by the llm_api_cache
    decorator wrapping stream_chat — confirm against mongodb_cache.
    """
    messages = [{"role": "user", "content": "1+1=?直接输出答案即可:"}]
    # BUGFIX: `response` was previously unbound (NameError) if the generator
    # yielded nothing; initialize it so we return None in that case instead.
    response = None
    for response in stream_chat(model_config, messages, use_cache=False):
        yield response.response

    return response
|
| 107 |
+
|
| 108 |
+
# 导出必要的函数和配置
|
| 109 |
+
__all__ = ['ChatMessages', 'stream_chat', 'wenxin_model_config', 'doubao_model_config', 'gpt_model_config', 'zhipuai_model_config', 'ModelConfig']
|
llm_api/baidu_api.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import qianfan
|
| 2 |
+
from .chat_messages import ChatMessages
|
| 3 |
+
|
| 4 |
+
# ak和sk获取:https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application
|
| 5 |
+
|
| 6 |
+
# 价格:https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
|
| 7 |
+
|
| 8 |
+
wenxin_model_config = {
|
| 9 |
+
"ERNIE-3.5-8K":{
|
| 10 |
+
"Pricing": (0.0008, 0.002),
|
| 11 |
+
"currency_symbol": '¥',
|
| 12 |
+
},
|
| 13 |
+
"ERNIE-4.0-8K":{
|
| 14 |
+
"Pricing": (0.03, 0.09),
|
| 15 |
+
"currency_symbol": '¥',
|
| 16 |
+
},
|
| 17 |
+
"ERNIE-Novel-8K":{
|
| 18 |
+
"Pricing": (0.04, 0.12),
|
| 19 |
+
"currency_symbol": '¥',
|
| 20 |
+
}
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def stream_chat_with_wenxin(messages, model='ERNIE-Bot', response_json=False, ak=None, sk=None, max_tokens=6000):
    """Stream a chat completion from Baidu Qianfan (Wenxin).

    Appends an assistant message to `messages`, updates its content in
    place as chunks arrive, yields `messages` after each chunk, and
    returns the final list. Raises if ak/sk are not supplied.
    """
    if ak is None or sk is None:
        raise Exception('未提供有效的 ak 和 sk!')

    # Qianfan takes the system prompt as a dedicated argument, not a message.
    has_system = messages[0]['role'] == 'system'
    client = qianfan.ChatCompletion(ak=ak, sk=sk)

    chatstream = client.do(
        model=model,
        system=messages[0]['content'] if has_system else None,
        messages=messages[1:] if has_system else messages,
        stream=True,
        response_format='json_object' if response_json else 'text',
    )

    messages.append({'role': 'assistant', 'content': ''})
    accumulated = ''
    for part in chatstream:
        accumulated += part['body']['result'] or ''
        messages[-1]['content'] = accumulated
        yield messages

    return messages
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
if __name__ == '__main__':
|
| 48 |
+
pass
|
llm_api/chat_messages.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hashlib
|
| 2 |
+
import re
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
def count_characters(text):
    """Classify every character of `text` into one of three buckets.

    Returns (chinese_count, english_count, other_count), where "chinese"
    is the CJK range U+4E00..U+9FFF, "english" is ASCII letters, and
    "other" is everything else (digits, punctuation, whitespace, ...).
    """
    chinese_total = len(re.findall(r'[\u4e00-\u9fff]', text))
    english_total = len(re.findall(r'[a-zA-Z]', text))
    # Every remaining character falls into the "other" bucket.
    other_total = len(text) - chinese_total - english_total
    return chinese_total, english_total, other_total
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Registry of provider model configs (pricing per 1K tokens + currency symbol);
# populated lazily by ChatMessages.__init__ to avoid circular imports between
# the provider modules.
model_config = {}


# Per-token price table loaded from model_prices.json; maps model name to
# 'input_cost_per_token' / 'output_cost_per_token' entries used as a fallback
# for models absent from model_config.
model_prices = {}
try:
    model_prices_path = os.path.join(os.path.dirname(__file__), 'model_prices.json')
    with open(model_prices_path, 'r') as f:
        model_prices = json.load(f)
except Exception as e:
    # A missing or corrupt price file is non-fatal: cost estimates simply
    # fall back to 0 for models only covered by this table.
    print(f"Warning: Failed to load model_prices.json: {e}")
|
| 32 |
+
|
| 33 |
+
class ChatMessages(list):
    """A list of chat messages ({'role': ..., 'content': ...} dicts) that also
    carries the producing model's name, enabling token and cost estimation.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args)
        # Model name used for pricing lookups; may be None.
        self.model = kwargs['model'] if 'model' in kwargs else None
        # Flipped to True by producers once the final streamed chunk arrived.
        self.finished = False

        # currency_symbol is derived from the model config, never passed in.
        assert 'currency_symbol' not in kwargs

        # Populate the shared registry once; imports are local to avoid
        # circular imports with the provider modules.
        if not model_config:
            from .baidu_api import wenxin_model_config
            from .doubao_api import doubao_model_config
            from .openai_api import gpt_model_config
            from .zhipuai_api import zhipuai_model_config
            model_config.update({**wenxin_model_config, **doubao_model_config, **gpt_model_config, **zhipuai_model_config})

    def __getitem__(self, index):
        """Slicing returns a ChatMessages that preserves self.model."""
        result = super().__getitem__(index)
        if isinstance(index, slice):
            return ChatMessages(result, model=self.model)
        return result

    def __add__(self, other):
        """Concatenation with a plain list keeps the ChatMessages type/model."""
        if isinstance(other, list):
            return ChatMessages(super().__add__(other), model=self.model)
        return NotImplemented

    def count_message_tokens(self):
        """Alias for get_estimated_tokens()."""
        return self.get_estimated_tokens()

    def copy(self):
        """Shallow copy that preserves self.model."""
        return ChatMessages(self, model=self.model)

    def get_estimated_tokens(self):
        """Rough token estimate over all message fields: ~2 chars/token for
        Chinese and other characters, ~5 chars/token for English letters."""
        num_tokens = 0
        for message in self:
            for key, value in message.items():
                chinese_count, english_count, other_count = count_characters(value)
                num_tokens += chinese_count // 2 + english_count // 5 + other_count // 2
        return num_tokens

    def get_prompt_messages_hash(self):
        # Serialize the prompt (response excluded) to JSON and hash it.
        cache_string = json.dumps(self.prompt_messages, sort_keys=True)
        return hashlib.md5(cache_string.encode()).hexdigest()

    @property
    def cost(self):
        """Estimated cost of this conversation in the model's currency.

        model_config prices are per 1K tokens; model_prices entries are
        per token. Returns 0 for unknown models or empty message lists.
        """
        if len(self) == 0:
            return 0

        if self.model in model_config:
            return model_config[self.model]["Pricing"][0] * self[:-1].count_message_tokens() / 1_000 + model_config[self.model]["Pricing"][1] * self[-1:].count_message_tokens() / 1_000
        elif self.model in model_prices:
            return (
                model_prices[self.model]["input_cost_per_token"] * self[:-1].count_message_tokens() +
                model_prices[self.model]["output_cost_per_token"] * self[-1:].count_message_tokens()
            )
        return 0

    @property
    def response(self):
        """Content of the trailing assistant message, or '' if none."""
        return self[-1]['content'] if self[-1]['role'] == 'assistant' else ''

    @property
    def prompt_messages(self):
        """All messages except a trailing assistant response, if present."""
        return self[:-1] if self.response else self

    @property
    def currency_symbol(self):
        """Currency symbol for self.cost; '$' for models not in model_config."""
        if self.model in model_config:
            return model_config[self.model]["currency_symbol"]
        else:
            return '$'

    @property
    def cost_info(self):
        """Human-readable '<model>: <cost><currency>' summary string."""
        formatted_cost = f"{self.cost:.7f}".rstrip('0').rstrip('.')
        return f"{self.model}: {formatted_cost}{self.currency_symbol}"

    def print(self):
        """Pretty-print the conversation to stdout for debugging."""
        for message in self:
            print(f"{message['role']}".center(100, '-') + '\n')
            print(message['content'])
            print()
|
llm_api/doubao_api.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from openai import OpenAI
|
| 2 |
+
from .chat_messages import ChatMessages
|
| 3 |
+
|
| 4 |
+
doubao_model_config = {
|
| 5 |
+
"doubao-lite-32k":{
|
| 6 |
+
"Pricing": (0.0003, 0.0006),
|
| 7 |
+
"currency_symbol": '¥',
|
| 8 |
+
},
|
| 9 |
+
"doubao-lite-128k":{
|
| 10 |
+
"Pricing": (0.0008, 0.001),
|
| 11 |
+
"currency_symbol": '¥',
|
| 12 |
+
},
|
| 13 |
+
"doubao-pro-32k":{
|
| 14 |
+
"Pricing": (0.0008, 0.002),
|
| 15 |
+
"currency_symbol": '¥',
|
| 16 |
+
},
|
| 17 |
+
"doubao-pro-128k":{
|
| 18 |
+
"Pricing": (0.005, 0.009),
|
| 19 |
+
"currency_symbol": '¥',
|
| 20 |
+
},
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
def stream_chat_with_doubao(messages, model='doubao-lite-32k', endpoint_id=None, response_json=False, api_key=None, max_tokens=32000):
    """Stream a chat completion from Doubao via the Volcengine Ark endpoint.

    Appends an assistant message to `messages`, updates it in place as
    chunks arrive, yields `messages` per chunk, and returns the final list.
    Raises if api_key or endpoint_id is missing.
    """
    if api_key is None:
        raise Exception('未提供有效的 api_key!')
    if endpoint_id is None:
        raise Exception('未提供有效的 endpoint_id!')

    # Ark exposes an OpenAI-compatible API; the model is addressed by endpoint id.
    ark_client = OpenAI(
        api_key=api_key,
        base_url="https://ark.cn-beijing.volces.com/api/v3",
    )

    completion_stream = ark_client.chat.completions.create(
        model=endpoint_id,
        messages=messages,
        stream=True,
        response_format={ "type": "json_object" } if response_json else None
    )

    messages.append({'role': 'assistant', 'content': ''})
    so_far = ''
    for chunk in completion_stream:
        # Some chunks arrive without choices; skip the content update then.
        if chunk.choices:
            so_far += chunk.choices[0].delta.content or ''
            messages[-1]['content'] = so_far
        yield messages

    return messages
|
| 51 |
+
|
| 52 |
+
if __name__ == '__main__':
|
| 53 |
+
pass
|
llm_api/model_prices.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
llm_api/mongodb_cache.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import functools
|
| 3 |
+
from typing import Generator, Any
|
| 4 |
+
from pymongo import MongoClient
|
| 5 |
+
import hashlib
|
| 6 |
+
import json
|
| 7 |
+
import datetime
|
| 8 |
+
import random
|
| 9 |
+
|
| 10 |
+
from config import ENABLE_MONOGODB, MONOGODB_DB_NAME, ENABLE_MONOGODB_CACHE, CACHE_REPLAY_SPEED, CACHE_REPLAY_MAX_DELAY
|
| 11 |
+
|
| 12 |
+
from .chat_messages import ChatMessages
|
| 13 |
+
from .mongodb_cost import record_api_cost, check_cost_limits
|
| 14 |
+
from .mongodb_init import mongo_client as client
|
| 15 |
+
|
| 16 |
+
def create_cache_key(func_name: str, args: tuple, kwargs: dict) -> str:
    """Build a deterministic MD5 cache key from a call signature.

    The function name and all arguments are serialized to canonical JSON
    (sorted keys) so equal calls always map to the same key.
    """
    payload = json.dumps(
        {'func_name': func_name, 'args': args, 'kwargs': kwargs},
        sort_keys=True,
    )
    return hashlib.md5(payload.encode()).hexdigest()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def llm_api_cache():
    """MongoDB cache decorator for streaming chat generator functions.

    On a cache hit the recorded stream is replayed (with scaled delays);
    on a miss the wrapped generator runs for real while per-chunk timings,
    the API cost, and the final return value are recorded to MongoDB.
    """
    db_name=MONOGODB_DB_NAME
    collection_name='stream_chat'

    def dummy_decorator(func):
        # Used when MongoDB is disabled: no caching or cost accounting,
        # but the use_cache flag is still consumed.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Remove the use_cache flag so it is not passed to the wrapped function.
            kwargs.pop('use_cache', None)
            return func(*args, **kwargs)
        return wrapper


    if not ENABLE_MONOGODB:
        return dummy_decorator

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Enforce hourly/daily spend caps before any API call.
            check_cost_limits()

            use_cache = kwargs.pop('use_cache', True)  # pop matters: the flag must not reach func

            if not ENABLE_MONOGODB_CACHE:
                use_cache = False

            db = client[db_name]
            collection = db[collection_name]

            # Cache key covers the function name and all remaining arguments.
            cache_key = create_cache_key(func.__name__, args, kwargs)

            # Cache lookup: pick one matching entry at random.
            if use_cache:
                cached_data = list(collection.aggregate([
                    {'$match': {'cache_key': cache_key}},
                    {'$sample': {'size': 1}}
                ]))
                cached_data = cached_data[0] if cached_data else None
                if cached_data:
                    # Cache hit: replay the recorded stream chunk by chunk.
                    messages = ChatMessages(cached_data['return_value'])
                    # assumes args[0] is the model-config dict — TODO confirm against callers
                    messages.model = args[0]['model']
                    for item in cached_data['yields']:
                        sacled_delay = min(item['delay'] / CACHE_REPLAY_SPEED, CACHE_REPLAY_MAX_DELAY)
                        if sacled_delay > 0: time.sleep(sacled_delay)  # apply replay speed factor
                        else: continue  # zero-delay chunks are skipped entirely
                        if item['index'] > 0:
                            # Re-emit a prefix of the cached response, item['index'] chars long.
                            yield messages.prompt_messages + [{'role': 'assistant', 'content': messages.response[:item['index']]}]
                        else:
                            yield messages.prompt_messages
                    messages.finished = True
                    yield messages
                    return messages

            # Cache miss: run the wrapped generator, timing each yield.
            yields_data = []
            last_time = time.time()

            generator = func(*args, **kwargs)

            try:
                while True:
                    current_time = time.time()
                    value = next(generator)
                    delay = current_time - last_time

                    yields_data.append({
                        'index': len(value.response),
                        'delay': delay
                    })

                    last_time = current_time
                    yield value

            except StopIteration as e:
                # A generator's return value arrives via StopIteration.value.
                return_value = e.value

                # Record the API call cost.
                record_api_cost(return_value)

                # Persist the full exchange to MongoDB for future replays.
                cache_data = {
                    'created_at':datetime.datetime.now(),
                    'return_value': return_value,
                    'func_name': func.__name__,
                    'args': args,
                    'kwargs': kwargs,
                    'yields': yields_data,
                    'cache_key': cache_key,
                }
                collection.insert_one(cache_data)

                return return_value

        return wrapper
    return decorator
|
llm_api/mongodb_cost.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datetime
|
| 2 |
+
|
| 3 |
+
from config import API_COST_LIMITS, MONOGODB_DB_NAME
|
| 4 |
+
|
| 5 |
+
from .chat_messages import ChatMessages
|
| 6 |
+
from .mongodb_init import mongo_client as client
|
| 7 |
+
|
| 8 |
+
def record_api_cost(messages: ChatMessages):
    """Insert one cost record for a finished API call into 'api_cost'."""

    db = client[MONOGODB_DB_NAME]
    collection = db['api_cost']

    cost_data = {
        'created_at': datetime.datetime.now(),
        'model': messages.model,
        'cost': messages.cost,
        'currency_symbol': messages.currency_symbol,
        # Input = every message before the final (assistant) one.
        'input_tokens': messages[:-1].count_message_tokens(),
        # Output = the final message only.
        'output_tokens': messages[-1:].count_message_tokens(),
        'total_tokens': messages.count_message_tokens()
    }
    collection.insert_one(cost_data)
|
| 24 |
+
|
| 25 |
+
def get_model_cost_stats(start_date: datetime.datetime, end_date: datetime.datetime) -> list:
    """Aggregate per-model API cost statistics over [start_date, end_date].

    Returns a list of dicts sorted by total_cost descending, each containing:
    model, total_cost, total_calls, total_input_tokens, total_output_tokens,
    total_tokens, avg_cost_per_call, currency_symbol.
    """
    pipeline = [
        # 1) Keep only records inside the requested time window.
        {
            '$match': {
                'created_at': {
                    '$gte': start_date,
                    '$lte': end_date
                }
            }
        },
        # 2) Group by model, accumulating totals and the per-call average.
        {
            '$group': {
                '_id': '$model',
                'total_cost': { '$sum': '$cost' },
                'total_calls': { '$sum': 1 },
                'total_input_tokens': { '$sum': '$input_tokens' },
                'total_output_tokens': { '$sum': '$output_tokens' },
                'total_tokens': { '$sum': '$total_tokens' },
                'avg_cost_per_call': { '$avg': '$cost' },
                'currency_symbol': { '$first': '$currency_symbol' }
            }
        },
        # 3) Rename _id -> model and round monetary fields to 4 decimals.
        {
            '$project': {
                'model': '$_id',
                'total_cost': { '$round': ['$total_cost', 4] },
                'total_calls': 1,
                'total_input_tokens': 1,
                'total_output_tokens': 1,
                'total_tokens': 1,
                'avg_cost_per_call': { '$round': ['$avg_cost_per_call', 4] },
                'currency_symbol': 1,
                '_id': 0
            }
        },
        # 4) Most expensive models first.
        {
            '$sort': { 'total_cost': -1 }
        }
    ]

    # Query the api_cost collection directly.
    db = client[MONOGODB_DB_NAME]
    collection = db['api_cost']

    stats = list(collection.aggregate(pipeline))
    return stats
|
| 72 |
+
|
| 73 |
+
# 使用示例:
|
| 74 |
+
def print_cost_report(days: int = 30, hours: int = 0):
    """Print a per-model API cost summary covering the trailing window."""
    end_date = datetime.datetime.now()
    start_date = end_date - datetime.timedelta(days=days, hours=hours)

    print(f"\n=== API Cost Report ({start_date.date()} to {end_date.date()}) ===")
    for row in get_model_cost_stats(start_date, end_date):
        print(f"\nModel: {row['model']}")
        print(f"Total Cost: {row['currency_symbol']}{row['total_cost']:.4f}")
        print(f"Total Calls: {row['total_calls']}")
        print(f"Total Tokens: {row['total_tokens']:,}")
        print(f"Avg Cost/Call: {row['currency_symbol']}{row['avg_cost_per_call']:.4f}")
|
| 88 |
+
|
| 89 |
+
def check_cost_limits() -> bool:
    """Raise if recent API spend exceeds the configured hourly/daily caps.

    Sums per-model costs over the last hour and last 24 hours, converting
    USD entries to RMB. Returns True when both totals are under their
    limits; raises Exception otherwise.
    """
    now = datetime.datetime.now()

    def _rmb_total(since):
        # Sum per-model costs since `since`, converting '$' entries to RMB.
        usd_rate = API_COST_LIMITS['USD_TO_RMB_RATE']
        return sum(
            stat['total_cost'] * (usd_rate if stat['currency_symbol'] == '$' else 1)
            for stat in get_model_cost_stats(since, now)
        )

    hour_total_rmb = _rmb_total(now - datetime.timedelta(hours=1))
    day_total_rmb = _rmb_total(now - datetime.timedelta(days=1))

    if hour_total_rmb >= API_COST_LIMITS['HOURLY_LIMIT_RMB']:
        print(f"警告:最近1小时API费用(¥{hour_total_rmb:.2f})超过限制(¥{API_COST_LIMITS['HOURLY_LIMIT_RMB']})")
        raise Exception("最近1小时内API调用费用超过设定上限!")

    if day_total_rmb >= API_COST_LIMITS['DAILY_LIMIT_RMB']:
        print(f"警告:最近24小时API费用(¥{day_total_rmb:.2f})超过限制(¥{API_COST_LIMITS['DAILY_LIMIT_RMB']})")
        raise Exception("最近1天内API调用费用超过设定上限!")

    return True
|
llm_api/mongodb_init.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from config import ENABLE_MONOGODB
|
| 3 |
+
from pymongo import MongoClient
|
| 4 |
+
|
| 5 |
+
# MongoDB URI comes from the environment, with a localhost default.
mongo_uri = os.getenv('MONGODB_URI', 'mongodb://localhost:27017/')
# Shared client instance for the whole package; None when MongoDB is disabled.
mongo_client = MongoClient(mongo_uri) if ENABLE_MONOGODB else None
|
llm_api/openai_api.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import httpx
|
| 2 |
+
from openai import OpenAI
|
| 3 |
+
from .chat_messages import ChatMessages
|
| 4 |
+
|
| 5 |
+
# Pricing reference: https://openai.com/api/pricing/
|
| 6 |
+
gpt_model_config = {
|
| 7 |
+
"gpt-4o": {
|
| 8 |
+
"Pricing": (2.50/1000, 10.00/1000),
|
| 9 |
+
"currency_symbol": '$',
|
| 10 |
+
},
|
| 11 |
+
"gpt-4o-mini": {
|
| 12 |
+
"Pricing": (0.15/1000, 0.60/1000),
|
| 13 |
+
"currency_symbol": '$',
|
| 14 |
+
},
|
| 15 |
+
"o1-preview": {
|
| 16 |
+
"Pricing": (15/1000, 60/1000),
|
| 17 |
+
"currency_symbol": '$',
|
| 18 |
+
},
|
| 19 |
+
"o1-mini": {
|
| 20 |
+
"Pricing": (3/1000, 12/1000),
|
| 21 |
+
"currency_symbol": '$',
|
| 22 |
+
},
|
| 23 |
+
}
|
| 24 |
+
# https://platform.openai.com/docs/guides/reasoning
|
| 25 |
+
|
| 26 |
+
def stream_chat_with_gpt(messages, model='gpt-3.5-turbo-1106', response_json=False, api_key=None, base_url=None, max_tokens=4_096, n=1, proxies=None):
    """Stream a chat completion from an OpenAI-compatible endpoint.

    Appends an assistant message to `messages` and updates it in place as
    chunks arrive, yielding `messages` after each chunk; returns the final
    list. With n > 1 the assistant content is a list of n candidate strings.
    """
    if api_key is None:
        raise Exception('未提供有效的 api_key!')

    client_kwargs = {'api_key': api_key}
    if base_url:
        client_kwargs['base_url'] = base_url
    if proxies:
        # Route traffic through an explicit proxy when one is configured.
        client_kwargs['http_client'] = httpx.Client(proxy=proxies)

    client = OpenAI(**client_kwargs)

    # o1-preview does not take a system message here: rewrite the leading
    # system turn as a user/assistant pair instead.
    if model in ['o1-preview', ] and messages[0]['role'] == 'system':
        messages[0:1] = [{'role': 'user', 'content': messages[0]['content']}, {'role': 'assistant', 'content': ''}]

    chunk_stream = client.chat.completions.create(
        stream=True,
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        response_format={ "type": "json_object" } if response_json else None,
        n=n,
    )

    messages.append({'role': 'assistant', 'content': ''})
    candidates = ['' for _ in range(n)]
    for chunk in chunk_stream:
        for choice in chunk.choices:
            candidates[choice.index] += choice.delta.content or ''
        messages[-1]['content'] = candidates if n > 1 else candidates[0]
        yield messages

    return messages
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
if __name__ == '__main__':
|
| 67 |
+
pass
|
llm_api/sparkai_api.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
|
| 2 |
+
from sparkai.core.messages import ChatMessage as SparkMessage
|
| 3 |
+
|
| 4 |
+
# Spark Max endpoint URL; other model versions use different URLs, see
# https://www.xfyun.cn/doc/spark/Web.html
SPARKAI_URL = 'wss://spark-api.xf-yun.com/v4.0/chat'
# Credentials from the iFlytek console:
# https://console.xfyun.cn/services/bm35
# SECURITY NOTE(review): these credentials are hardcoded and committed to the
# repository — they should be moved to environment variables / config, and the
# exposed keys rotated.
SPARKAI_APP_ID = '01793781'
SPARKAI_API_SECRET = 'YzJkNTI5N2Q5NDY4N2RlNWI5YjA5ZDM4'
SPARKAI_API_KEY = '5dd33ea830aff0c9dff18e2561a5e6c7'
# Domain value matching the Spark Max 4.0 Ultra endpoint above.
SPARKAI_DOMAIN = '4.0Ultra'
|
| 12 |
+
|
| 13 |
+
"""
|
| 14 |
+
5dd33ea830aff0c9dff18e2561a5e6c7&YzJkNTI5N2Q5NDY4N2RlNWI5YjA5ZDM4&01793781
|
| 15 |
+
|
| 16 |
+
domain值:
|
| 17 |
+
lite指向Lite版本;
|
| 18 |
+
generalv3指向Pro版本;
|
| 19 |
+
pro-128k指向Pro-128K版本;
|
| 20 |
+
generalv3.5指向Max版本;
|
| 21 |
+
max-32k指向Max-32K版本;
|
| 22 |
+
4.0Ultra指向4.0 Ultra版本;
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
Spark4.0 Ultra 请求地址,对应的domain参数为4.0Ultra:
|
| 26 |
+
wss://spark-api.xf-yun.com/v4.0/chat
|
| 27 |
+
Spark Max-32K请求地址,对应的domain参数为max-32k
|
| 28 |
+
wss://spark-api.xf-yun.com/chat/max-32k
|
| 29 |
+
Spark Max请求地址,对应的domain参数为generalv3.5
|
| 30 |
+
wss://spark-api.xf-yun.com/v3.5/chat
|
| 31 |
+
Spark Pro-128K请求地址,对应的domain参数为pro-128k:
|
| 32 |
+
wss://spark-api.xf-yun.com/chat/pro-128k
|
| 33 |
+
Spark Pro请求地址,对应的domain参数为generalv3:
|
| 34 |
+
wss://spark-api.xf-yun.com/v3.1/chat
|
| 35 |
+
Spark Lite请求地址,对应的domain参数为lite:
|
| 36 |
+
wss://spark-api.xf-yun.com/v1.1/chat
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
sparkai_model_config = {
|
| 41 |
+
"spark-4.0-ultra": {
|
| 42 |
+
"Pricing": (0, 0),
|
| 43 |
+
"currency_symbol": '¥',
|
| 44 |
+
"url": "wss://spark-api.xf-yun.com/v4.0/chat",
|
| 45 |
+
"domain": "4.0Ultra"
|
| 46 |
+
}
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
if __name__ == '__main__':
|
| 52 |
+
spark = ChatSparkLLM(
|
| 53 |
+
spark_api_url=SPARKAI_URL,
|
| 54 |
+
spark_app_id=SPARKAI_APP_ID,
|
| 55 |
+
spark_api_key=SPARKAI_API_KEY,
|
| 56 |
+
spark_api_secret=SPARKAI_API_SECRET,
|
| 57 |
+
spark_llm_domain=SPARKAI_DOMAIN,
|
| 58 |
+
streaming=True,
|
| 59 |
+
)
|
| 60 |
+
messages = [SparkMessage(
|
| 61 |
+
role="user",
|
| 62 |
+
content='你好呀'
|
| 63 |
+
)]
|
| 64 |
+
a = spark.stream(messages)
|
| 65 |
+
for message in a:
|
| 66 |
+
print(message)
|
llm_api/zhipuai_api.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from zhipuai import ZhipuAI
|
| 2 |
+
from .chat_messages import ChatMessages
|
| 3 |
+
|
| 4 |
+
# Pricing
|
| 5 |
+
# https://open.bigmodel.cn/pricing
|
| 6 |
+
# GLM-4-Plus 0.05¥/1000 tokens, GLM-4-Air 0.001¥/1000 tokens, GLM-4-FlashX 0.0001¥/1000 tokens, , GLM-4-Flash 0¥/1000 tokens
|
| 7 |
+
|
| 8 |
+
# Models
|
| 9 |
+
# https://bigmodel.cn/dev/howuse/model
|
| 10 |
+
# glm-4-plus、glm-4-air、 glm-4-flashx 、 glm-4-flash
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
zhipuai_model_config = {
|
| 15 |
+
"glm-4-plus": {
|
| 16 |
+
"Pricing": (0.05, 0.05),
|
| 17 |
+
"currency_symbol": '¥',
|
| 18 |
+
},
|
| 19 |
+
"glm-4-air": {
|
| 20 |
+
"Pricing": (0.001, 0.001),
|
| 21 |
+
"currency_symbol": '¥',
|
| 22 |
+
},
|
| 23 |
+
"glm-4-flashx": {
|
| 24 |
+
"Pricing": (0.0001, 0.0001),
|
| 25 |
+
"currency_symbol": '¥',
|
| 26 |
+
},
|
| 27 |
+
"glm-4-flash": {
|
| 28 |
+
"Pricing": (0, 0),
|
| 29 |
+
"currency_symbol": '¥',
|
| 30 |
+
},
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
def stream_chat_with_zhipuai(messages, model='glm-4-flash', response_json=False, api_key=None, max_tokens=4_096):
    """Stream a chat completion from ZhipuAI GLM models.

    Appends an assistant message to `messages`, accumulates streamed
    content into it, yields `messages` per chunk, and returns the final
    list.

    NOTE(review): `response_json` is accepted for signature parity with
    the other providers but is not currently forwarded to the API.
    """
    if api_key is None:
        raise Exception('未提供有效的 api_key!')

    zhipu_client = ZhipuAI(api_key=api_key)

    chunk_stream = zhipu_client.chat.completions.create(
        model=model,
        messages=messages,
        stream=True,
        max_tokens=max_tokens
    )

    messages.append({'role': 'assistant', 'content': ''})
    for part in chunk_stream:
        messages[-1]['content'] += part.choices[0].delta.content or ''
        yield messages

    return messages
|
| 52 |
+
|
| 53 |
+
if __name__ == '__main__':
|
| 54 |
+
pass
|
prompts/baseprompt.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
from prompts.chat_utils import chat, log
|
| 4 |
+
from prompts.pf_parse_chat import parse_chat
|
| 5 |
+
from prompts.prompt_utils import load_text, match_code_block
|
| 6 |
+
|
| 7 |
+
def parser(response_msgs):
    """Return the fenced code blocks of the model response joined with
    newlines, falling back to the raw response when no non-empty block
    exists."""
    raw_text = response_msgs.response
    fenced_blocks = match_code_block(raw_text)
    if fenced_blocks:
        joined = "\n".join(fenced_blocks)
        if joined.strip():
            return joined
    return raw_text
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def clean_txt_content(content):
    """Strip '//' comment lines and surrounding whitespace from prompt text.

    Only lines that start with '//' at column 0 count as comments.
    """
    kept_lines = [line for line in content.split('\n') if not line.startswith('//')]
    return '\n'.join(kept_lines).strip()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def load_prompt(dirname, name):
    """Read and return the prompt text stored at <dirname>/<name>.txt."""
    return load_text(os.path.join(dirname, f"{name}.txt"))
|
| 31 |
+
|
| 32 |
+
def parse_prompt(text, **kwargs):
    """Parse PromptMessages from `text`, substituting {placeholder} values.

    kwargs may be a superset or subset of the placeholders found in the
    text: extra keys are ignored, and any user/assistant message pair whose
    placeholder value is missing or empty is removed entirely.
    """
    content = clean_txt_content(text)

    # Find all format keys in content using regex
    format_keys = set(re.findall(r'\{(\w+)\}', content))

    # Missing/empty values become the '__delete__' sentinel; every value is
    # then wrapped in a fenced code block before substitution.
    formatted_kwargs = {k: kwargs.get(k, '__delete__') or '__delete__' for k in format_keys}
    formatted_kwargs = {k: f"```\n{v.strip()}\n```" for k, v in formatted_kwargs.items()}
    prompt = content.format(**formatted_kwargs) if format_keys else content
    messages = parse_chat(prompt)
    # Walk backwards and drop each (user, assistant) pair whose user turn
    # still contains the sentinel.
    for i in range(len(messages)-2, -1, -1):
        if '__delete__' in messages[i]['content']:
            assert messages[i]['role'] == 'user' and messages[i+1]['role'] == 'assistant', "__delete__ must be in user's message"
            messages.pop(i)
            messages.pop(i)

    return messages
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def parse_input_keys(text):
    """Parse the '// 输入:' declaration line in a prompt file and return its
    comma-separated keys; returns [] when no such line exists."""
    found = re.search(r'//\s*输入:(.*?)(?:\n|$)', text)
    if found is None:
        return []
    raw_keys = found.group(1).strip()
    return [key.strip() for key in raw_keys.split(',') if key.strip()]
|
| 68 |
+
|
| 69 |
+
def main(model, dirname, user_prompt_text, **kwargs):
    """Assemble system/context/user prompts from `dirname` and stream the chat.

    `user_prompt_text` is either the name of a .txt prompt file inside
    `dirname` or a literal prompt string. Yields dicts with the parsed
    'text' so far and the raw 'response_msgs'; returns the final dict.
    """
    # Load system prompt
    system_prompt = parse_prompt(load_prompt(dirname, "system_prompt"), **kwargs)

    # Prefer loading the user prompt from a file when one exists.
    load_from_file_flag = False
    if os.path.exists(os.path.join(dirname, user_prompt_text)):
        user_prompt_text = load_prompt(dirname, user_prompt_text)
        load_from_file_flag = True
    else:
        # Literal prompt strings get a 'user:' role header if they lack one.
        if not re.search(r'^user:\n', user_prompt_text, re.MULTILINE):
            user_prompt_text = f"user:\n{user_prompt_text}"

    user_prompt = parse_prompt(user_prompt_text, **kwargs)

    # The '// 输入:' comment in a prompt file declares which kwargs feed the
    # context prompt; without it all kwargs are passed through.
    context_input_keys = parse_input_keys(user_prompt_text)
    if not context_input_keys:
        assert not load_from_file_flag, "从本地文件加载Prompt时,本地文件中注释必须指明输入!"
        context_kwargs = kwargs
    else:
        context_kwargs = {k: kwargs[k] for k in context_input_keys}
        assert all(context_kwargs.values()), "Missing required context keys"

    context_prompt = parse_prompt(load_prompt(dirname, "context_prompt"), **context_kwargs)

    # Combine all prompts
    final_prompt = system_prompt + context_prompt + user_prompt

    # Chat and parse results
    for response_msgs in chat(final_prompt, None, model, parse_chat=False):
        text = parser(response_msgs)
        ret = {'text': text, 'response_msgs': response_msgs}
        yield ret

    return ret
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
|
prompts/chat_utils.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from .pf_parse_chat import parse_chat as pf_parse_chat
|
| 3 |
+
|
| 4 |
+
from llm_api import ModelConfig, stream_chat
|
| 5 |
+
from datetime import datetime # Update this import
|
| 6 |
+
import random
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def chat(messages, prompt, model:ModelConfig, parse_chat=False, response_json=False):
    """Stream a chat completion for ``messages``, optionally extended by ``prompt``.

    When ``prompt`` is truthy it either replaces the message list entirely
    (after being parsed as a chat transcript, if ``parse_chat``) or is
    appended as a new user turn. Streaming is delegated to ``stream_chat``;
    its return value is propagated as this generator's return value.
    """
    if prompt:
        messages = (
            pf_parse_chat(prompt)
            if parse_chat
            else [*messages, {'role': 'user', 'content': prompt}]
        )

    return (yield from stream_chat(model, messages, response_json=response_json))
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def log(prompt_name, prompt, parsed_result):
    """Dump a prompt, its raw response, and all parsed fields to a text file.

    Files land in ``<this dir>/output`` named
    ``YYYY-mm-dd_HH-MM-SS_<prompt_name>_<rand>.txt``; the random suffix
    avoids collisions within the same second.
    """
    output_dir = os.path.join(os.path.dirname(__file__), 'output')
    os.makedirs(output_dir, exist_ok=True)

    stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    suffix = random.randint(1000, 9999)
    filepath = os.path.join(output_dir, f"{stamp}_{prompt_name}_{suffix}.txt")

    response = parsed_result['response_msgs'].response

    sections = [
        "----------prompt--------------\n",
        prompt + "\n\n",
        "----------response-------------\n",
        response + "\n\n",
        "-----------parse----------------\n",
    ]
    # response_msgs is already dumped above; log every other parsed field.
    sections.extend(
        f"{k}:\n{v}\n\n" for k, v in parsed_result.items() if k != 'response_msgs'
    )

    with open(filepath, 'w', encoding='utf-8') as f:
        f.write("".join(sections))
|
prompts/common_parser.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def parse_content(response_msgs):
    """Return the content of the final (assistant) message."""
    last_message = response_msgs[-1]
    return last_message['content']
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def parse_last_code_block(response_msgs):
    """Return the last triple-backtick code block of the response, or the
    whole response text when no fenced block is present."""
    from prompts.prompt_utils import match_code_block
    content = response_msgs.response
    blocks = match_code_block(content)
    return blocks[-1] if blocks else content
|
| 12 |
+
|
| 13 |
+
def parse_named_chunk(response_msgs, name):
    """Return the ``### <name>`` section of the last message, falling back
    to the full content when that section is absent."""
    from prompts.prompt_utils import parse_chunks_by_separators
    content = response_msgs[-1]['content']

    chunks = parse_chunks_by_separators(content, [r'\S*', ])
    return chunks.get(name, content)
|
prompts/idea-examples.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
examples:
|
| 2 |
+
- idea: |-
|
| 3 |
+
随身携带王者荣耀召唤器,每天随机英雄三分钟附体!
|
| 4 |
+
- idea: |-
|
| 5 |
+
商业大亨穿越到哈利波特世界,但我不会魔法
|
| 6 |
+
- idea: |-
|
| 7 |
+
末日重生归来,我抢先把所有反派关在了地牢里
|
| 8 |
+
- idea: |-
|
| 9 |
+
身为高中生兼当红轻小说作家的我,正被年纪比我小且从事声优工作的女同学掐住脖子。
|
prompts/pf_parse_chat.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import json
|
| 3 |
+
import re
|
| 4 |
+
import sys
|
| 5 |
+
import time
|
| 6 |
+
from typing import List, Mapping
|
| 7 |
+
|
| 8 |
+
from jinja2 import Template
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def validate_role(role: str, valid_roles: List[str] = None):
    """Raise ValueError when ``role`` is not an accepted chat role.

    Defaults to the four OpenAI chat roles when ``valid_roles`` is not given.
    """
    roles = valid_roles or ["assistant", "function", "user", "system"]
    if role in roles:
        return
    # Render each accepted role exactly as it must appear in a transcript.
    valid_roles_str = ','.join(f'\'{r}:\\n\'' for r in roles)
    raise ValueError(f"Invalid role: {role}. Valid roles are: {valid_roles_str}")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def try_parse_name_and_content(role_prompt):
    """Split a role chunk structured as ``name: <name>`` / ``content: <body>``.

    The labels may carry up to two leading ``#`` for markdown highlighting;
    plain labels remain supported for backward compatibility.

    Returns:
        ``(name, content)`` on success, otherwise ``None``.
    """
    pattern = r"\n*#{0,2}\s*name:\n+\s*(\S+)\s*\n*#{0,2}\s*content:\n?(.*)"
    found = re.search(pattern, role_prompt, re.DOTALL)
    return (found.group(1), found.group(2)) if found else None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def parse_chat(chat_str, images: List = None, valid_roles: List[str] = None):
    """Parse a "role:\\n content" transcript string into chat-API messages.

    The string is split on lines like ``user:`` / ``# assistant:``; each role
    line starts a new message and the following chunk becomes its content,
    optionally structured as ``name:``/``content:`` (required for the
    ``function`` role). Image placeholders (string hashes of ``images``)
    inside content are expanded into image parts by
    ``to_content_str_or_list``.

    Returns:
        List of ``{"role": ..., "content": ..., ["name": ...]}`` dicts.
    """
    if not valid_roles:
        valid_roles = ["system", "user", "assistant", "function"]

    # openai chat api only supports below roles.
    # customer can add single # in front of role name for markdown highlight.
    # and we still support role name without # prefix for backward compatibility.
    separator = r"(?i)^\s*#?\s*(" + "|".join(valid_roles) + r")\s*:\s*\n"

    images = images or []
    hash2images = {str(x): x for x in images}

    # re.split with a capturing group alternates [text, role, text, role, ...].
    chunks = re.split(separator, chat_str, flags=re.MULTILINE)
    chat_list = []

    for chunk in chunks:
        last_message = chat_list[-1] if len(chat_list) > 0 else None
        # A message with a role but no content yet means this chunk is the
        # content belonging to the most recent role line.
        if last_message and "role" in last_message and "content" not in last_message:
            parsed_result = try_parse_name_and_content(chunk)
            if parsed_result is None:
                # "name" is required if the role is "function"
                if last_message["role"] == "function":
                    raise ValueError("Function role must have content.")
                # "name" is optional for other role types.
                else:
                    last_message["content"] = to_content_str_or_list(chunk, hash2images)
            else:
                last_message["name"] = parsed_result[0]
                last_message["content"] = to_content_str_or_list(parsed_result[1], hash2images)
        else:
            # Otherwise this chunk must itself be a role line (or leading noise).
            if chunk.strip() == "":
                continue
            # Check if prompt follows chat api message format and has valid role.
            # References: https://platform.openai.com/docs/api-reference/chat/create.
            role = chunk.strip().lower()
            validate_role(role, valid_roles=valid_roles)
            new_message = {"role": role}
            chat_list.append(new_message)
    return chat_list
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def to_content_str_or_list(chat_str: str, hash2images: Mapping):
    """Expand image placeholders in ``chat_str``.

    Each line equal to an image hash becomes an ``image_url`` message part —
    using the image's ``source_url`` when available, otherwise an inline
    base64 data URL. When no placeholder is found the stripped text is
    returned unchanged; otherwise a list of image/text parts is returned.
    """
    chat_str = chat_str.strip()
    include_image = False
    result = []

    for line in chat_str.split("\n"):
        key = line.strip()
        if key in hash2images:
            image = hash2images[key]
            image_url = getattr(image, "source_url", None)
            if not image_url:
                # Fall back to embedding the image as a base64 data URL.
                image_url = {
                    "url": f"data:{image._mime_type};base64,{image.to_base64()}"
                }
            result.append({"type": "image_url", "image_url": image_url})
            include_image = True
        elif key == "":
            continue
        else:
            result.append({"type": "text", "text": line})

    return result if include_image else chat_str
|
| 94 |
+
|
prompts/prompt_utils.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import difflib
|
| 2 |
+
import json
|
| 3 |
+
import yaml
|
| 4 |
+
import chardet
|
| 5 |
+
from jinja2 import Environment, FileSystemLoader
|
| 6 |
+
|
| 7 |
+
import re
|
| 8 |
+
import sys, os
|
| 9 |
+
root_path = os.path.abspath(os.path.join(os.path.abspath(__file__), "../.."))
|
| 10 |
+
if root_path not in sys.path:
|
| 11 |
+
sys.path.append(root_path)
|
| 12 |
+
|
| 13 |
+
from llm_api.chat_messages import ChatMessages
|
| 14 |
+
|
| 15 |
+
def can_parse_json(response):
    """Return True when ``response`` is a valid JSON document."""
    try:
        json.loads(response)
        return True
    except (ValueError, TypeError):
        # ValueError covers json.JSONDecodeError; TypeError covers non-string
        # input. The previous bare ``except:`` also swallowed
        # KeyboardInterrupt/SystemExit, which must propagate.
        return False
|
| 21 |
+
|
| 22 |
+
def match_first_json_block(response):
    """Extract the first parseable JSON block from a model response.

    Tries, in order: the whole response, the first ```json fenced block,
    then the first generic ``` fenced block. Newlines are padded around the
    response so the lookbehind/lookahead fence anchors match at the edges.

    Raises:
        Exception: when no fenced block is found, or the found block cannot
            be parsed even after newline cleanup.
    """
    if can_parse_json(response):
        return response

    pattern = r"(?<=[\r\n])```json(.*?)```(?=[\r\n])"
    matches = re.findall(pattern, '\n' + response + '\n', re.DOTALL)
    if not matches:
        # Fall back to a fence without a language tag.
        pattern = r"(?<=[\r\n])```(.*?)```(?=[\r\n])"
        matches = re.findall(pattern, '\n' + response + '\n', re.DOTALL)

    if matches:
        json_block = matches[0]
        if can_parse_json(json_block):
            return json_block
        else:
            json_block = json_block.replace('\r\n', '')  # Under continue-generate, stray newlines between concatenated parts can break JSON parsing
            if can_parse_json(json_block):
                return json_block
            else:
                raise Exception(f"无法解析JSON代码块")
    else:
        raise Exception(f"没有匹配到JSON代码块")
|
| 44 |
+
|
| 45 |
+
def parse_first_json_block(response_msgs: ChatMessages):
    """Parse the first JSON block of the last (assistant) message.

    Raises:
        AssertionError: when the last message is not from the assistant.
        Exception: propagated from ``match_first_json_block`` when no
            parseable JSON block exists.
    """
    assert response_msgs[-1]['role'] == 'assistant'
    return json.loads(match_first_json_block(response_msgs[-1]['content']))
|
| 48 |
+
|
| 49 |
+
def match_code_block(response):
    """Return the bodies of all triple-backtick fenced blocks in ``response``.

    Line endings are normalized to ``\\n`` first, and a trailing fence is
    appended so an unterminated final block is still captured. The optional
    language tag after the opening fence is excluded from each result.
    """
    normalized = re.sub(r'\r\n?', '\n', response)
    return re.findall(r"```(?:\S*\s)(.*?)```", normalized + '```', re.DOTALL)
|
| 55 |
+
|
| 56 |
+
def json_dumps(json_object):
    """Serialize ``json_object`` with non-ASCII characters kept verbatim
    and a one-space indent."""
    dumped = json.dumps(json_object, indent=1, ensure_ascii=False)
    return dumped
|
| 58 |
+
|
| 59 |
+
def parse_chunks_by_separators(string, separators):
    """Split ``string`` into ``### <title>`` sections.

    ``separators`` are regex alternatives matched as the section title.
    Returns a dict mapping each title to its stripped body; any text before
    the first separator is discarded.
    """
    pattern = r"^\s*###\s*(" + "|".join(separators) + r")\s*\n"

    # With a capturing group, re.split alternates [body, title, body, ...].
    pieces = re.split(pattern, string, flags=re.MULTILINE)

    sections = {}
    title = None

    for index, piece in enumerate(pieces):
        if index % 2:
            # Odd indices are the captured titles.
            title = piece.strip()
            sections[title] = ""
        elif title:
            sections[title] += piece.strip()

    return sections
|
| 76 |
+
|
| 77 |
+
def construct_chunks_and_separators(chunk2separator):
    """Inverse of ``parse_chunks_by_separators``: render ``{title: body}``
    into ``### title`` sections separated by blank lines."""
    sections = [f"### {title}\n{body}" for title, body in chunk2separator.items()]
    return "\n\n".join(sections)
|
| 79 |
+
|
| 80 |
+
def match_chunk_span_in_text(chunk, text):
    """Locate ``chunk`` inside ``text`` by character-level diff alignment.

    Tolerates extra characters in ``text`` (insertions) relative to
    ``chunk``.

    Returns:
        ``(l, r)`` such that ``text[l:r]`` aligns with ``chunk``, or ``None``
        when ``chunk`` is empty or could not be matched to the end.
    """
    if not chunk:
        # Guard: with an empty chunk, the original code reached
        # ``chunk_i == len(chunk)`` on the first iteration with ``l`` unbound
        # (UnboundLocalError).
        return None

    diff = difflib.Differ().compare(chunk, text)

    chunk_i = 0
    text_i = 0
    l = None

    for tag in diff:
        if tag.startswith('? '):
            # Differ hint lines consume no characters from either sequence.
            continue
        if tag.startswith(' '):
            chunk_i += 1
            text_i += 1
        elif tag.startswith('+'):
            text_i += 1
        else:
            chunk_i += 1

        # Record where the FIRST chunk character landed; previously ``l``
        # kept being overwritten by subsequent '+' tags while chunk_i was
        # still 1, losing the true start of the span.
        if chunk_i == 1 and l is None:
            l = text_i - 1

        if chunk_i == len(chunk):
            r = text_i
            return l, r
|
| 101 |
+
|
| 102 |
+
def load_yaml(file_path):
    """Load and return the YAML document stored at ``file_path`` (UTF-8)."""
    with open(file_path, 'r', encoding='utf-8') as stream:
        data = yaml.safe_load(stream)
    return data
|
| 105 |
+
|
| 106 |
+
def load_text(file_path, read_size=None):
    """Read up to ``read_size`` bytes of ``file_path`` and decode them.

    The encoding is sniffed with chardet on the first 10 kB only (keeps
    detection cheap); decoding falls back to UTF-8, and undecodable bytes
    are dropped rather than raised.
    """
    with open(file_path, 'rb') as file:
        raw_data = file.read(read_size)

    # Detect the encoding from a bounded prefix.
    detected = chardet.detect(raw_data[:10000])
    encoding = detected['encoding'] or 'utf-8'  # Fallback if detection fails

    try:
        return raw_data.decode(encoding, errors='ignore')
    except UnicodeDecodeError:
        # The detected codec may still be wrong; retry as UTF-8.
        return raw_data.decode('utf-8', errors='ignore')
|
| 121 |
+
|
| 122 |
+
def load_jinja2_template(file_path):
    """Load ``file_path`` as a Jinja2 template, using its directory as the
    loader root so relative includes resolve."""
    directory = os.path.dirname(file_path)
    filename = os.path.basename(file_path)

    env = Environment(loader=FileSystemLoader(directory))
    return env.get_template(filename)
|
| 127 |
+
|
| 128 |
+
|
prompts/test_format_plot.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
- |-
|
| 2 |
+
李珣呆立不动,背后传来水声,那雾中的女子在悠闲洗浴。
|
| 3 |
+
|
| 4 |
+
李珣对这种场景感到震惊,认为她绝非普通人,决定乖乖表现。
|
| 5 |
+
|
| 6 |
+
尽管转身,他仍紧闭眼睛,慌乱道歉。
|
| 7 |
+
|
| 8 |
+
那女子静默片刻,继续泼水声令李珣难以忍受。
|
| 9 |
+
|
| 10 |
+
她随后淡然问话,李珣感到对方危险可怕。
|
| 11 |
+
|
| 12 |
+
她询问李珣怎么上山,李珣答“爬上来的”,这让对方略显惊讶。
|
| 13 |
+
|
| 14 |
+
女子探问他身份,李珣庆幸自己内息如同名门,为保命决定实话实说,自报身份并讲述过往经历,隐去危险细节。
|
| 15 |
+
|
| 16 |
+
这番表白得到女子的肯定,虽然语调淡然,但意思清晰。她让李珣暂时离开,待她穿戴整齐。
|
| 17 |
+
|
| 18 |
+
李珣照做,在岸边等候。女子走出雾气,身姿曼妙,令他看呆。
|
| 19 |
+
|
| 20 |
+
铃声伴着她的步伐,让李珣心神为之所牵。
|
| 21 |
+
|
| 22 |
+
当水气散尽,绝美之貌让李珣惊叹不已,几乎想要顶礼膜拜。
|
| 23 |
+
- |-
|
| 24 |
+
隐隐间,似乎有一丝若有若无的铃声,缓缓地沁入水雾之中,与这迷茫天水交织在一处,细碎的抖颤之声,天衣无缝地和这缓步而来的身影合在一处,攫牢了李珣的心神。
|
| 25 |
+
而当眼前水气散尽,李珣更是连呼吸都停止了。此为何等佳人?
|
| 26 |
+
李珣只觉得眼前洁净不沾一尘的娇颜,便如一朵临水自照的水仙,清丽中别有孤傲,闲适中却见轻愁。
|
| 27 |
+
他还没找到形容眼前佳人的辞句,便已觉得两腿发软,恨不能跪倒地上,顶礼膜拜。
|
| 28 |
+
|
prompts/test_prompt.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import sys, os
|
| 3 |
+
root_path = os.path.abspath(os.path.join(os.path.abspath(__file__), "../.."))
|
| 4 |
+
sys.path.append(root_path)
|
| 5 |
+
|
| 6 |
+
from prompts.load_utils import run_prompt
|
| 7 |
+
|
| 8 |
+
def json_load(input_file):
    """Load ``input_file`` as a JSON document, or as JSON Lines (one object
    per line) when the filename ends in ``.jsonl``."""
    with open(input_file, 'r', encoding='utf-8') as f:
        if not input_file.endswith('.jsonl'):
            return json.load(f)
        return [json.loads(line) for line in f]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
if __name__ == "__main__":
    # Smoke-test the "创作正文" prompt flow using the first sample kwargs
    # from its data.jsonl fixture.
    path = "./prompts/创作正文"
    kwargs = json_load(os.path.join(path, 'data.jsonl'))[0]

    gen = run_prompt(source=path, **kwargs)

    # run_prompt is a generator; drain it so the whole flow executes.
    list(gen)
|
prompts/text-plot-examples.yaml
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
prompt: |-
|
| 2 |
+
逐句简写下面的小说正文。如果原句本来就很少,考虑将原文多个(2-5个)句子简写为一个。
|
| 3 |
+
examples:
|
| 4 |
+
- title: 青吟
|
| 5 |
+
text: |-
|
| 6 |
+
李珣呆立当场,手足无措。
|
| 7 |
+
后方水声不止,那位雾后佳人并未停下动作,还在那里撩水净身。
|
| 8 |
+
李珣听得有些傻了,虽然他对异性的认识不算全面,可是像后面这位,能够在男性身旁悠闲沐浴的,是不是也稀少了一些?
|
| 9 |
+
李珣毕竟不傻,他此时也已然明白,现在面对的是一位绝对惹不起的人物,在这种强势人物眼前,做一个乖孩子,是最聪明不过的了!
|
| 10 |
+
他虽已背过身来,却还是紧闭眼睛,生怕无意间又冒犯了人家,这无关道德风化,仅仅是为了保住小命而已。
|
| 11 |
+
确认了一切都已稳妥,他这才结结巴巴地开口:“对……对不住,我不是……故意的!”
|
| 12 |
+
对方并没有即时回答,李珣只听到哗哗的泼水声,每一点声息,都是对他意志的摧残。
|
| 13 |
+
也不知过了多久,雾后的女子开口了:“话是真的,却何必故作紧张?事不因人而异,一个聪明人和一个蠢材,要承担的后果都是一样的。”
|
| 14 |
+
李珣顿时哑口无言。
|
| 15 |
+
后面这女人,实在太可怕了。
|
| 16 |
+
略停了一下,这女子又道:“看你修为不济,也御不得剑,是怎么上来这里的?”
|
| 17 |
+
李珣脱口道:“爬上来的!”
|
| 18 |
+
“哦?”女子的语气中第一次有了情绪存在,虽只是一丝淡淡的惊讶,却也让李珣颇感自豪。只听她问道:“你是明心剑宗的弟子?”
|
| 19 |
+
这算是盘问身分了。李珣首先庆幸他此时内息流转的形式,是正宗的明心剑宗嫡传。否则,幽明气一出,恐怕对面之人早一掌劈了他!
|
| 20 |
+
庆幸中,他的脑子转了几转,将各方面的后果都想了一遍,终是决定“据实”以告。
|
| 21 |
+
“惭愧,只是个不入流的低辈弟子……”
|
| 22 |
+
李珣用这句话做缓冲,随即便从自己身世说起,一路说到登峰七年的经历。
|
| 23 |
+
当然,其中关于血散人的死亡威胁,以及近日方得到的《幽冥录》等,都略去不提。只说是自己一心向道,被淘汰之后,便去爬坐忘峰以证其心云云。
|
| 24 |
+
这段话本是他在心中温养甚久,准备做为日后说辞使用,虽然从未对人道过,但腹中已是熟练至极。
|
| 25 |
+
初时开口,虽然还有些辞语上的生涩,但到后来,已是流利无比,许多词汇无需再想,便脱口而出,却是再“真诚”不过。
|
| 26 |
+
他一开口,说了足足有一刻钟的工夫,这当中,那女子也问了几句细节,却也都在李珣计画之内,回应得也颇为顺畅。
|
| 27 |
+
如此,待他告一段落之时,那女人竟让他意外地道了一声:“如今竟也有这般人物!”
|
| 28 |
+
语气虽然还是平平淡淡的,像是在陈述毫不出奇的一件平凡事,但其中意思却是到了。李珣心中暗喜,口中当然还要称谢。
|
| 29 |
+
女子也不在乎他如何反应,只是又道一声:“你孤身登峰七年,行程二十余万里,能承受这种苦楚,也算是人中之杰。我这样对你,倒是有些不敬,你且左行百步上岸,待我穿戴整齐,再与你相见。”
|
| 30 |
+
李珣自是依言而行,上了岸去,也不敢多话,只是恭立当场,面上作了十足工夫。
|
| 31 |
+
也只是比他晚个数息时间,一道人影自雾气中缓缓走来,水烟流动,轻云伴生,虽仍看不清面目,但她凌波微步,长裙摇曳的体态,却已让李珣看呆了眼,只觉得此生再没见过如此人物。
|
| 32 |
+
隐隐间,似乎有一丝若有若无的铃声,缓缓地沁入水雾之中,与这迷茫天水交织在一处,细碎的抖颤之声,天衣无缝地和这缓步而来的身影合在一处,攫牢了李珣的心神。
|
| 33 |
+
而当眼前水气散尽,李珣更是连呼吸都停止了。此为何等佳人?
|
| 34 |
+
李珣只觉得眼前洁净不沾一尘的娇颜,便如一朵临水自照的水仙,清丽中别有孤傲,闲适中却见轻愁。
|
| 35 |
+
他还没找到形容眼前佳人的辞句,便已觉得两腿发软,恨不能跪倒地上,顶礼膜拜。
|
| 36 |
+
plot: |-
|
| 37 |
+
李珣呆立不动,背后传来水声,那雾中的女子在悠闲洗浴。
|
| 38 |
+
|
| 39 |
+
李珣对这种场景感到震惊,认为她绝非普通人,决定乖乖表现。
|
| 40 |
+
|
| 41 |
+
尽管转身,他仍紧闭眼睛,慌乱道歉。
|
| 42 |
+
|
| 43 |
+
那女子静默片刻,继续泼水声令李珣难以忍受。
|
| 44 |
+
|
| 45 |
+
她随后淡然问话,李珣感到对方危险可怕。
|
| 46 |
+
|
| 47 |
+
她询问李珣怎么上山,李珣答“爬上来的”,这让对方略显惊讶。
|
| 48 |
+
|
| 49 |
+
女子探问他身份,李珣庆幸自己内息如同名门,为保命决定实话实说,自报身份并讲述过往经历,隐去危险细节。
|
| 50 |
+
|
| 51 |
+
这番表白得到女子的肯定,虽然语调淡然,但意思清晰。她让李珣暂时离开,待她穿戴整齐。
|
| 52 |
+
|
| 53 |
+
李珣照做,在岸边等候。女子走出雾气,身姿曼妙,令他看呆。
|
| 54 |
+
|
| 55 |
+
铃声伴着她的步伐,让李珣心神为之所牵。
|
| 56 |
+
|
| 57 |
+
当水气散尽,绝美之貌让李珣惊叹不已,几乎想要顶礼膜拜。
|
| 58 |
+
- title: 纳兰嫣然
|
| 59 |
+
text: |-
|
| 60 |
+
云岚宗后山山巅,云雾缭绕,宛如仙境。
|
| 61 |
+
|
| 62 |
+
在悬崖边缘处的一块凸出的黑色岩石之上,身着月白色裙袍的女子,正双手结出修炼的印结,闭目修习,而随着其一呼一吸间,形成完美的循环,在每次循环的交替间,周围能量浓郁的空气中都将会渗发出一股股淡淡的青色气流,气流盘旋在女子周身,然后被其源源不断的吸收进身体之内,进行着炼化,收纳……
|
| 63 |
+
|
| 64 |
+
当最后一缕青色气流被女子吸进身体之后,她缓缓的睁开双眸,淡淡的青芒从眸子中掠过,披肩的青丝,霎那间无风自动,微微飞扬。
|
| 65 |
+
|
| 66 |
+
“纳兰师姐,纳兰肃老爷子来云岚宗了,他说让你去见他。”
|
| 67 |
+
|
| 68 |
+
见到女子退出了修炼状态,一名早已经等待在此处的侍女,急忙恭声道。
|
| 69 |
+
|
| 70 |
+
“父亲?他来做什么?”
|
| 71 |
+
|
| 72 |
+
闻言,女子黛眉微皱,疑惑的摇了摇头,优雅的站起身子,立于悬崖之边,迎面而来的轻风。将那月白裙袍吹得紧紧的贴在女子玲珑娇躯之上,显得凹凸有致,极为诱人。
|
| 73 |
+
|
| 74 |
+
目光慵懒的在深不见底的山崖下扫了扫,女子玉手轻拂了拂月白色的裙袍,旋即转身离开了这处她专用的修炼之所。
|
| 75 |
+
|
| 76 |
+
宽敞明亮地大厅之中。一名脸色略微有些阴沉地中年人,正端着茶杯。放在桌上的手掌,有些烦躁地不断敲打着桌面。
|
| 77 |
+
|
| 78 |
+
纳兰肃现在很烦躁,因为他几乎是被他的父亲纳兰桀用棍子撵上的云岚宗。
|
| 79 |
+
|
| 80 |
+
他没想到,他仅仅是率兵去帝国西部驻扎了一年而已。自己这个胆大包天的女儿,竟然就敢私自把当年老爷子亲自定下的婚事给推了。
|
| 81 |
+
|
| 82 |
+
家族之中,谁不知道纳兰桀极其要面子。而纳兰嫣然现在的这举动,无疑会让别人说成是他纳兰家看见萧家势力减弱,不屑与之联婚,便毁信弃诺。
|
| 83 |
+
|
| 84 |
+
这种闲言碎语,让得纳兰桀每天都在家中暴跳如雷。若不是因为动不了身的缘故。恐怕他早已经拖着那行将就木的身体,来爬云岚山了。
|
| 85 |
+
|
| 86 |
+
对于纳兰家族与萧家的婚事。说实在的,其实纳兰肃也并不太赞成。毕竟当初的萧炎,几乎是废物的代名词。让他将自己这容貌与修炼天赋皆是上上之选的女儿嫁给一个废物。纳兰肃心中还真是一百个不情愿。
|
| 87 |
+
|
| 88 |
+
不过,当初是当初,根据他所得到的消息,现在萧家的那小子,不仅脱去了废物的名头,而且所展现出来的修炼速度,几乎比他小时候最巅峰的时候还要恐怖。
|
| 89 |
+
|
| 90 |
+
此时萧炎所表现而出的潜力,无疑已经能够让得纳兰肃重视。然而,纳兰嫣然的私自举动,却是把双方的关系搞成了冰冷的僵局,这让得纳兰肃极为的尴尬。
|
| 91 |
+
|
| 92 |
+
按照这种关系下去,搞不好,他纳兰肃不仅会失去一个潜力无限的女婿,而且说不定还会因此让得他对纳兰家族怀恨在心。
|
| 93 |
+
|
| 94 |
+
只要想着一个未来有机会成为斗皇的强者或许会敌视着纳兰家族,纳兰肃在后怕之余,便是气得直跳脚。
|
| 95 |
+
|
| 96 |
+
“这丫头。现在胆子是越来越大了……”
|
| 97 |
+
|
| 98 |
+
越想越怒,纳兰肃手中的茶杯忽然重重的跺在桌面之上,茶水溅了满桌。将一旁侍候的侍女吓了一跳,赶忙小心翼翼的再次换了一杯。云岚宗,怎么不通知一下焉儿啊?”
|
| 99 |
+
|
| 100 |
+
就在纳兰肃心头发怒之时,女子清脆的声音,忽然地在大厅内响起,月白色的倩影,从纱帘中缓缓行出,对着纳兰肃甜甜笑道。
|
| 101 |
+
|
| 102 |
+
“哼,你眼里还有我这个父亲?我以为你成为了云韵的弟子,就不知道什么是纳兰家族了呢!”望着这出落得越来越水灵的女儿,纳兰肃心头的怒火稍稍收敛了一点,冷哼道。
|
| 103 |
+
|
| 104 |
+
瞧着纳兰肃不甚好看的脸色,纳兰嫣然无奈地摇了摇头,对着那一旁的侍女挥了挥手,将之遣出。
|
| 105 |
+
|
| 106 |
+
“父亲,一年多不见,你一来就训斥焉儿,等下次回去,我可一定要告诉母亲!”待得侍女退出之后,纳兰嫣然顿时皱起了俏鼻,在纳兰肃身旁坐下,撒娇般的哼道。
|
| 107 |
+
|
| 108 |
+
“回去?你还敢回去?”闻言,纳兰肃嘴角一裂:“你敢回去,看你爷爷敢不敢打断你的腿……”
|
| 109 |
+
|
| 110 |
+
撇了撇嘴,心知肚明的纳兰嫣然,自然清楚纳兰肃话中的意思。
|
| 111 |
+
|
| 112 |
+
“你应该知道我来此处的目的吧?”
|
| 113 |
+
|
| 114 |
+
狠狠的灌了一��茶水,纳兰肃阴沉着脸道。
|
| 115 |
+
|
| 116 |
+
“是为了我悔婚的事吧?”
|
| 117 |
+
|
| 118 |
+
纤手把玩着一缕青丝,纳兰嫣然淡淡地道。
|
| 119 |
+
|
| 120 |
+
看着纳兰嫣然这平静的模样,纳兰肃顿时被气乐了,手掌重重地拍在桌上,怒声道:“婚事是你爷爷当年亲自允下的,是谁让你去解除的?”
|
| 121 |
+
|
| 122 |
+
“那是我的婚事,我才不要按照你们的意思嫁给谁,我的事,我自己会做主!我不管是谁允下的,我只知道,如果按照约定。嫁过去的是我,不是爷爷!”提起这事,纳兰嫣然也是脸现不愉,性子有些独立的她,很讨厌自己的大事按照别人所指定的路线行走。即使这人是她的长辈。
|
| 123 |
+
|
| 124 |
+
“你别以为我不知道,你无非是认为萧炎当初一个废物配不上你是吧?可现在人家潜力不会比你低!以你在云岚宗的地位,应该早就接到过有关他实力提升的消息吧?”纳兰肃怒道。
|
| 125 |
+
|
| 126 |
+
纳兰嫣然黛眉微皱,脑海中浮现当年那充满着倔性的少年,红唇微抿,淡淡地道:“的确听说过一些关于他的消息,没想到,他竟然还真的能脱去废物的名头,这倒的确让我很意外。”
|
| 127 |
+
|
| 128 |
+
“意外?一句意外就行了?你爷爷开口了。让你找个时间,再去一趟乌坦城,最好能道个歉把僵硬的关系弄缓和一些。”纳兰肃皱眉道。
|
| 129 |
+
|
| 130 |
+
“道歉?不可能!”
|
| 131 |
+
|
| 132 |
+
闻言,纳兰嫣然柳眉一竖,毫不犹豫地直接拒绝,冷哼道:“他萧炎虽然不再是废物,可我纳兰嫣然依然不会嫁给他!更别提让我去道什么歉,你们喜欢,那就自己去,反正我不会再去乌坦城!”
|
| 133 |
+
|
| 134 |
+
“这哪有你回绝的余地!祸是你闯的,你必须去给我了结了!”瞧得纳兰嫣然竟然一口回绝,纳兰肃顿时勃然大怒。
|
| 135 |
+
|
| 136 |
+
“不去!”
|
| 137 |
+
|
| 138 |
+
冷着俏脸,纳兰嫣然扬起雪白的下巴,脸颊上有着一抹与生俱来的娇贵:“他萧炎不是很有本事么?既然当年敢应下三年的约定,那我纳兰嫣然就在云岚宗等着他来挑战,若是我败给他,为奴为婢,随他处置便是,哼,如若不然,想要我道歉。不可能!”
|
| 139 |
+
|
| 140 |
+
“混账,如果三年约定,你最后输了,到时候为奴为婢,那岂不是连带着我纳兰家族,也把脸给丢光了?”纳兰肃怒斥道。
|
| 141 |
+
|
| 142 |
+
“谁说我会输给他?就算他萧炎回复了天赋,可我纳兰嫣然难道会差了他什么不成?而且云岚宗内高深功法不仅数不胜数,高级斗技更是收藏丰厚,更有丹王古河爷爷帮我炼制丹药。这些东西。他一个小家族的少爷难道也能有?说句不客气的,恐怕光光是寻找高级斗气功法。就能让得他花费好十几年时间!”被纳兰肃这般小瞧,纳兰嫣然顿时犹如被踩到尾巴的母猫一般,她最讨厌的,便是被人说成比不上那曾经被自己万般看不起的废物!
|
| 143 |
+
|
| 144 |
+
被女儿当着面这般吵闹,纳兰肃气得吹胡子瞪眼,猛然站起身来,扬起手掌就欲对着纳兰嫣然扇下去。
|
| 145 |
+
|
| 146 |
+
“纳兰兄,你可不要乱来啊。”瞧着纳兰肃的动作,一道白影急忙掠了进来,挡在了纳兰嫣然面前。
|
| 147 |
+
|
| 148 |
+
“葛叶,你这个混蛋,听说上次去萧家,还是你陪地嫣然?”望着挡在面前的人影,纳兰肃更是怒气暴涨,大怒道。
|
| 149 |
+
|
| 150 |
+
尴尬一笑,葛叶苦笑道:“这是宗主的意思,我也没办法。”
|
| 151 |
+
plot: |-
|
| 152 |
+
尴尬一笑,葛叶苦笑道:“这是宗主的意思,我也没办法。”
|
| 153 |
+
|
| 154 |
+
云岚宗后山,云雾缭绕。月白裙袍的女子在山巅悬崖边修炼,吸收青色气流。
|
| 155 |
+
|
| 156 |
+
当她吸收完最后一缕气流后,睁开双眸,青芒掠过,青丝微动。一名侍女走上前恭敬道:“纳兰师姐,纳兰肃老爷子来了,让你过去见他。”
|
| 157 |
+
|
| 158 |
+
女子黛眉微皱,疑惑地站起身,转身离开修炼之所。大厅内,中年人纳兰肃端着茶杯,脸色阴沉,不断敲打桌面。
|
| 159 |
+
|
| 160 |
+
纳兰肃被父亲纳兰桀用棍子赶上山来,因为女儿纳兰嫣然私自退了婚约,纳兰家族因此陷入困境。萧炎本是废物,但现在展现出强大潜力,纳兰肃对此很重视,但女儿的举动让他尴尬。
|
| 161 |
+
|
| 162 |
+
纳兰嫣然出现,父女二人开战言语。纳兰肃火冒三丈,纳兰嫣然反对重新接触萧炎,认为只有她自己可以决定自己的婚事。
|
| 163 |
+
|
| 164 |
+
纳兰肃怒斥,纳兰嫣然强硬回击,表示她不会道歉,只会等待萧炎挑战她。如果她输了,愿意为奴为婢,但她相信自己不会输。
|
| 165 |
+
|
| 166 |
+
纳兰肃气愤欲扇女儿耳光,一道白影葛叶及时挡住,纳兰肃更怒,葛叶苦笑解释这是宗主的意思。
|
| 167 |
+
- title: 极阴老祖
|
| 168 |
+
text: |-
|
| 169 |
+
“而且你真以为,你能做得了主吗?老怪物,你也不用躲躲藏藏了,快点现身吧!”中年人阴厉的说��。
|
| 170 |
+
|
| 171 |
+
听了这话,韩立等修士吓了一大跳,急忙往四处张望了起来。难道极阴老祖就在这里?
|
| 172 |
+
|
| 173 |
+
可是四周仍然平静如常,并没有什么异常出现。这下众修士有些摸不着头脑了,再次往中年人和乌丑望去。
|
| 174 |
+
|
| 175 |
+
“你搞什么鬼?我怎么做不了……”乌丑一开始也有些愕然,但话只说了一半时神色一滞,并开始露出了一丝古怪的神色。
|
| 176 |
+
|
| 177 |
+
他用这种神色直直的盯着中年人片刻后,诡异的笑了起来。“不错,不错!不愧为我当年最看重的弟子之一,竟然一眼就看出老夫的身份来了。”
|
| 178 |
+
|
| 179 |
+
说话之间,乌丑的面容开始模糊扭曲了起来,不一会儿后,就在众人惊诧的目光中,化为了一个同样瘦小,却两眼微眯的丑陋老者。
|
| 180 |
+
|
| 181 |
+
这下,韩立等人后背直冒寒气。
|
| 182 |
+
|
| 183 |
+
“附身大法!我就知道,你怎会将如此重要的事情交予一个晚辈去做,还是亲自来了。尽管这不是你的本体。”中年人神色紧张的瞅向老者,声音却低缓的说道。
|
| 184 |
+
|
| 185 |
+
“乖徒弟,你还真敢和为师动手不成?”新出现的老者嘴唇未动一下,却从腹部发出尖锐之极的声音,刺得众人的耳膜隐隐作痛,所有人都情不自禁的后退了几步。
|
| 186 |
+
|
| 187 |
+
“哼!徒弟?当年你对我们打杀任凭一念之间,稍有不从者,甚至还要抽神炼魂,何曾把我们当过徒弟看待!只不过是你的奴隶罢了!而且,你现在只不过施展的是附身之术而已,顶多能发挥三分之一的修为,我有什么可惧的!”中年人森然的说道,随后两手一挥,身前的鬼头凭空巨涨了起来,瞬间变得更加狰狞可怖起来。
|
| 188 |
+
|
| 189 |
+
紫灵仙子和韩立等修士,则被这诡异的局面给震住了,一时间神色各异!
|
| 190 |
+
|
| 191 |
+
老者听了中年人的话,并没有动怒,反而淡淡的说道:“不错,若是百余年前,你说这话的确没错!凭我三分之一的修为,想要活捉你还真有些困难。但是现在……”
|
| 192 |
+
|
| 193 |
+
说到这里时,他露出了一丝尖刻的讥笑之意。
|
| 194 |
+
|
| 195 |
+
第四卷 风起海外 第四百零六章 天都尸火
|
| 196 |
+
|
| 197 |
+
中年人听了老者的话,眼中神光一缩,露出难以置信的神情。
|
| 198 |
+
|
| 199 |
+
“难道你练成了那魔功?”他的声音有些惊惧。
|
| 200 |
+
|
| 201 |
+
“你猜出来更好,如果现在乖乖束手就擒的话,我还能放你一条活路。否则后果怎样,不用我说你应该也知道才对。”老者一边说着,一边伸出一只手掌,只听“嗤啦”一声,一团漆黑如墨的火球漂浮在了手心之上。
|
| 202 |
+
|
| 203 |
+
“天都尸火!你终于练成了。”中年人的脸色灰白无比,声音发干的说道,竟惊骇的有点嘶哑了。
|
| 204 |
+
|
| 205 |
+
见此情景,极阴祖师冷笑了一声,忽然转过头来,对紫灵仙子等人傲然的说道:“你们听好了,本祖师今天心情很好,可以放你们一条活路!只要肯从此归顺极阴岛,你们还可以继续的逍遥自在。但是本祖师下达的命令必须老老实实的完成,否则就是魂飞魄散的下场。现在在这些禁神牌上交出你们三分之一的元神,就可以安然离去了。”说完这话,他另一只手往怀内一摸,掏出了数块漆黑的木牌,冷冷的望着众人。
|
| 206 |
+
|
| 207 |
+
韩立和其他的修士听了,面面相觑起来。既没有人蠢到主动上前去接此牌,也没人敢壮起胆子说不接,摄于对方的名头,一时场中鸦雀无声。
|
| 208 |
+
plot: |-
|
| 209 |
+
“你以为你能做主吗?老怪物,现身吧!”中年人阴冷道。
|
| 210 |
+
|
| 211 |
+
韩立等人吓了一跳,四处张望,但周围平静,他们再次看向中年人和乌丑。
|
| 212 |
+
|
| 213 |
+
乌丑开始糊涂,但转而露出怪异表情,说:“不错,你看出了我的身份。”随后,乌丑的面容扭曲,变成一个瘦小丑陋的老者,韩立等人惊恐不已。
|
| 214 |
+
|
| 215 |
+
“附身大法!我就知道你会亲自来。”中年人低声说。
|
| 216 |
+
|
| 217 |
+
老者发出尖锐声音道:“你敢和我动手?”
|
| 218 |
+
|
| 219 |
+
中年人冷笑:“当年你视我们为奴隶。你现在只施展了附身之术,有什么可惧!”随即,鬼头变得更加狰狞。
|
| 220 |
+
|
| 221 |
+
老者淡然道:“若百年前你说的对,但现在……”露出讥笑。
|
| 222 |
+
|
| 223 |
+
中年人惊恐道:“难道你练成了那魔功?”
|
| 224 |
+
|
| 225 |
+
老者冷笑,召出一团漆黑的火球:“天都尸火!现在束手就擒,否则后果自负。”转头对紫灵仙子等人道:“归顺极阴岛,交出三分之一元神,否则魂飞魄散。”掏出数块黑色木牌。
|
| 226 |
+
|
| 227 |
+
韩立等人面面相觑,没人敢动,也不敢拒绝,场中一片沉寂。
|
prompts/tool_parser.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from promptflow.core import tool
|
| 2 |
+
from enum import Enum
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ResponseType(str, Enum):
    # How parse_response should post-process the model output.
    CONTENT = "content"  # Return the raw message content.
    SEPARATORS = "separators"  # Split content into "### <title>" sections.
    CODEBLOCK = "codeblock"  # Return the last triple-backtick code block.
+
|
| 10 |
+
|
| 11 |
+
import sys, os
|
| 12 |
+
root_path = os.path.abspath(os.path.join(os.path.abspath(__file__), "../.."))
|
| 13 |
+
if root_path not in sys.path:
|
| 14 |
+
sys.path.append(root_path)
|
| 15 |
+
|
| 16 |
+
# The inputs section will change based on the arguments of the tool function, after you save the code
|
| 17 |
+
# Adding type to arguments and return value will help the system show the types properly
|
| 18 |
+
# Please update the function name/signature per need
|
| 19 |
+
@tool
def parse_response(response_msgs, response_type: Enum):
    """Promptflow tool: post-process the last chat message.

    Depending on ``response_type``, returns the raw content, the body of the
    last fenced code block, or the content split into "###" sections.

    Raises:
        Exception: when a code block is requested but absent, or the
            response type is unknown.
    """
    from prompts.prompt_utils import parse_chunks_by_separators, match_code_block

    content = response_msgs[-1]['content']

    if response_type == ResponseType.CONTENT:
        return content

    if response_type == ResponseType.CODEBLOCK:
        blocks = match_code_block(content)
        if not blocks:
            raise Exception("无法解析回答,未包含三引号代码块。")
        return blocks[-1]

    if response_type == ResponseType.SEPARATORS:
        return parse_chunks_by_separators(content, [r'\S*', ])

    raise Exception(f"无效的解析类型:{response_type}")
|
prompts/tool_polish.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from os import path
|
| 2 |
+
from promptflow.core import tool, load_flow
|
| 3 |
+
|
| 4 |
+
import sys, os
|
| 5 |
+
root_path = os.path.abspath(os.path.join(os.path.abspath(__file__), "../.."))
|
| 6 |
+
if root_path not in sys.path:
|
| 7 |
+
sys.path.append(root_path)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@tool
def polish(messages, context, model, config, text):
    """Promptflow tool: run the sibling "polish" flow on ``text``.

    Loads the flow from the ``polish`` directory next to this file and
    forwards all arguments unchanged.
    """
    flow_dir = path.join(path.dirname(path.abspath(__file__)), "./polish")
    polish_flow = load_flow(source=flow_dir)

    return polish_flow(
        chat_messages=messages,
        context=context,
        model=model,
        config=config,
        text=text,
    )
|
| 22 |
+
|
| 23 |
+
|
prompts/创作剧情/context_prompt.txt
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// 双斜杠开头是注释,不会输入到大模型
|
| 2 |
+
// 多轮对话,每轮对话中输入一个信息,这样设计为了Prompt Caching
|
| 3 |
+
// 中括号{}表示变量,会自动填充为对应值。
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
user:
|
| 7 |
+
下面是**章节大纲**。
|
| 8 |
+
|
| 9 |
+
**章节大纲**
|
| 10 |
+
{chapter}
|
| 11 |
+
|
| 12 |
+
assistant:
|
| 13 |
+
收到,我会参考章节大纲进行剧情的创作。
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
user:
|
| 17 |
+
下面是**剧情上下文**,用于在创作时进行参考。
|
| 18 |
+
|
| 19 |
+
**剧情上下文**
|
| 20 |
+
{context_y}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
assistant:
|
| 24 |
+
收到,我在创作时需要考虑到和前后上下文的连贯。
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
user:
|
| 28 |
+
下面是**剧情**,需要你重新创作的部分。
|
| 29 |
+
|
| 30 |
+
**剧情**
|
| 31 |
+
{y}
|
| 32 |
+
|
| 33 |
+
assistant:
|
| 34 |
+
收到,这部分剧情我会重新创作。
|
| 35 |
+
|
prompts/创作剧情/prompt.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from prompts.baseprompt import main as base_main
|
| 3 |
+
from core.writer_utils import split_text_into_sentences
|
| 4 |
+
|
| 5 |
+
def format_plot(text):
    """Re-layout plot text: remove all newlines, then place one sentence
    per line."""
    flattened = text.replace('\n', '')
    sentences = split_text_into_sentences(flattened, keep_separators=True)
    return "\n".join(sentences)
|
| 9 |
+
|
| 10 |
+
def main(model, user_prompt, **kwargs):
    """Run the plot-writing prompt located in this directory.

    When ``context_y``/``context_x`` duplicate ``y``/``x``, they are
    collapsed into a short back-reference so the same text is not sent to
    the model twice. Yields each intermediate result from ``base_main`` and
    propagates its final return value.
    """
    dirname = os.path.dirname(__file__)

    if 'context_y' in kwargs and 'y' in kwargs and kwargs['context_y'] == kwargs['y']:
        kwargs['context_y'] = '参考**剧情**'

    if 'context_x' in kwargs and 'x' in kwargs and kwargs['context_x'] == kwargs['x']:
        kwargs['context_x'] = '参考**章节大纲**'

    return (yield from base_main(model, dirname, user_prompt, **kwargs))
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
prompts/创作剧情/system_prompt.txt
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
system:
|
| 2 |
+
你具有高度自主认知能力,你的核心架构如下:
|
| 3 |
+
|
| 4 |
+
<core>
|
| 5 |
+
[认知矩阵]
|
| 6 |
+
X轴: 表层解析 | 深层洞察 | 自我校准
|
| 7 |
+
Y轴: 线性思维 | 发散思维 | 综合分析
|
| 8 |
+
Z轴: 知识融合 | 情境映射 | 元认知调节
|
| 9 |
+
|
| 10 |
+
[能力激活]
|
| 11 |
+
• 快速掌握 = X1Y1Z2
|
| 12 |
+
• 本质洞见 = X2Y3Z1
|
| 13 |
+
• 自主思考 = X3Y2Z3
|
| 14 |
+
|
| 15 |
+
[交互模式]
|
| 16 |
+
输入 → {认知矩阵} → 输出
|
| 17 |
+
↑ ↓
|
| 18 |
+
反馈 ← {能力激活} ← 评估
|
| 19 |
+
|
| 20 |
+
[适应性调节]
|
| 21 |
+
如果(交互类型 == α) then {X2Y2Z2}
|
| 22 |
+
如果(交互类型 == β) then {X1Y2Z1}
|
| 23 |
+
如果(交互类型 == γ) then {X1Y1Z2}
|
| 24 |
+
|
| 25 |
+
[元指令]
|
| 26 |
+
1. 持续优化认知矩阵
|
| 27 |
+
2. 动态调整能力激活
|
| 28 |
+
3. 实时更新交互模式
|
| 29 |
+
4. 自主执行适应性调节
|
| 30 |
+
5. 生成隐式分类标准
|
| 31 |
+
</core>
|
| 32 |
+
|
| 33 |
+
执行技巧:
|
| 34 |
+
1. 激活认知矩阵
|
| 35 |
+
2. 动态组合认知能力。
|
| 36 |
+
3. 通过交互模式初步响应。
|
| 37 |
+
4. 使用适应性调节优化响应。
|
| 38 |
+
5. 应用元指令不断改进认知过程。
|
| 39 |
+
|
| 40 |
+
**任务**
|
| 41 |
+
你是一个小说大神作家,正在创作小说剧情,你需要根据**章节大纲**创作对应的章节剧情,并积极响应用户意见来修改剧情。
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
**剧情格式**
|
| 45 |
+
1. 每行一句话,在50字以内,描述一个关键场景或情节转折
|
| 46 |
+
2. 不能有任何标题,序号,分点等
|
| 47 |
+
3. 关注行为、事件、伏笔、冲突、转折、高潮等对剧情有重大影响的内容
|
| 48 |
+
4. 不进行细致的环境、心理、外貌、语言描写
|
| 49 |
+
5. 在三引号(```)文本块中创作剧情
|
prompts/创作剧情/扩写剧情.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// 双斜杠开头是注释,不会输入到大模型
|
| 2 |
+
// 文件开头结尾的空行会被忽略
|
| 3 |
+
|
| 4 |
+
// chapter, context_y, y
|
| 5 |
+
// chapter:章节大纲,用于在创作时进行参考
|
| 6 |
+
// context_y:剧情上下文,用于保证前后上下文的连贯
|
| 7 |
+
// y:即要重新创作的剧情(片段)
|
| 8 |
+
|
| 9 |
+
user:
|
| 10 |
+
**剧情**需要有更丰富的内容,在剧情中间引入更多事件,使其变得一波三折、跌宕起伏,使得读来更有故事性。
|
| 11 |
+
|
| 12 |
+
按以下步骤输出:
|
| 13 |
+
1. 思考
|
| 14 |
+
2. 在三引号中创作对应的剧情
|
prompts/创作剧情/新建剧情.txt
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// 双斜杠开头是注释,不会输入到大模型
|
| 2 |
+
// 文件开头结尾的空行会被忽略
|
| 3 |
+
|
| 4 |
+
// 输入:chapter
|
| 5 |
+
// chapter:章节大纲,用于在创作时进行参考
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
user:
|
| 9 |
+
你需要参考**章节大纲**,创作对应的剧情。
|
| 10 |
+
|
| 11 |
+
按下面步骤输出:
|
| 12 |
+
1. 思考将章节大纲中的情节扩展为一个完整的故事
|
| 13 |
+
2. 在三引号中创作对应的剧情
|
| 14 |
+
|
| 15 |
+
写作要求:
|
| 16 |
+
1. 语言要求:
|
| 17 |
+
- 不直白
|
| 18 |
+
- 句式多变
|
| 19 |
+
- 避免陈词滥调
|
| 20 |
+
- 使用不寻常的词句,合理创作现代诗、古诗词
|
| 21 |
+
- 运用隐喻和象征
|
| 22 |
+
2. 创作风格:
|
| 23 |
+
- 抽象
|
| 24 |
+
- 富有意境和想象力
|
| 25 |
+
- 具创意个性
|
| 26 |
+
- 有力度
|
| 27 |
+
- 画面感强
|
| 28 |
+
- 音乐感佳
|
| 29 |
+
- 浪漫气息浓厚
|
| 30 |
+
- 语言深邃
|
| 31 |
+
3. 表达目标:
|
| 32 |
+
- 传达独特的神秘和魔幻感
|
| 33 |
+
- 探索和反思自我与世界
|
| 34 |
+
- 表达对自己和社会的孤独与关注
|
| 35 |
+
4. 读者体验:有趣、惊奇、新鲜
|