deeme committed on
Commit
217acfe
·
verified ·
1 Parent(s): 729847c

Upload 111 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .env +86 -0
  2. Dockerfile +8 -0
  3. README.md +6 -5
  4. app.py +341 -0
  5. backend_utils.py +22 -0
  6. config.py +81 -0
  7. core/__init__.py +1 -0
  8. core/backend.py +218 -0
  9. core/diff_utils.py +173 -0
  10. core/draft_writer.py +46 -0
  11. core/frontend.py +435 -0
  12. core/frontend_copy.py +35 -0
  13. core/frontend_setting.py +345 -0
  14. core/frontend_utils.py +333 -0
  15. core/outline_writer.py +88 -0
  16. core/parser_utils.py +32 -0
  17. core/plot_writer.py +49 -0
  18. core/summary_novel.py +94 -0
  19. core/writer.py +533 -0
  20. core/writer_utils.py +216 -0
  21. custom/根据提纲创作正文/天蚕土豆风格.txt +14 -0
  22. custom/根据提纲创作正文/对草稿进行润色.txt +7 -0
  23. healthcheck.py +24 -0
  24. llm_api/__init__.py +109 -0
  25. llm_api/baidu_api.py +48 -0
  26. llm_api/chat_messages.py +116 -0
  27. llm_api/doubao_api.py +53 -0
  28. llm_api/model_prices.json +0 -0
  29. llm_api/mongodb_cache.py +127 -0
  30. llm_api/mongodb_cost.py +121 -0
  31. llm_api/mongodb_init.py +7 -0
  32. llm_api/openai_api.py +67 -0
  33. llm_api/sparkai_api.py +66 -0
  34. llm_api/zhipuai_api.py +54 -0
  35. prompts/baseprompt.py +105 -0
  36. prompts/chat_utils.py +40 -0
  37. prompts/common_parser.py +21 -0
  38. prompts/idea-examples.yaml +9 -0
  39. prompts/pf_parse_chat.py +94 -0
  40. prompts/prompt_utils.py +128 -0
  41. prompts/test_format_plot.yaml +28 -0
  42. prompts/test_prompt.py +22 -0
  43. prompts/text-plot-examples.yaml +227 -0
  44. prompts/tool_parser.py +39 -0
  45. prompts/tool_polish.py +23 -0
  46. prompts/创作剧情/context_prompt.txt +35 -0
  47. prompts/创作剧情/prompt.py +26 -0
  48. prompts/创作剧情/system_prompt.txt +49 -0
  49. prompts/创作剧情/扩写剧情.txt +14 -0
  50. prompts/创作剧情/新建剧情.txt +35 -0
.env ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Thread Configuration - 线程配置
2
+ # 生成时采用的最大线程数,5-10即可。会带来成倍的API调用费用,不要设置过高!
3
+ MAX_THREAD_NUM=1
4
+
5
+
6
+ # Server Configuration - Docker服务配置
7
+ # 前端服务端口
8
+ FRONTEND_PORT=80
9
+ # 后端服务端口
10
+ BACKEND_PORT=7869
11
+ # 后端服务监听地址
12
+ BACKEND_HOST=0.0.0.0
13
+ # Gunicorn工作进程数
14
+ WORKERS=4
15
+ # 每个工作进程的线程数
16
+ THREADS=2
17
+ # 请求超时时间(秒)
18
+ TIMEOUT=120
19
+
20
+ # 是否启用在线演示
21
+ # 不用设置,默认不启用
22
+ ENABLE_ONLINE_DEMO=False
23
+
24
+ # Backend Configuration - 后端配置
25
+ # 导入小说时,最大的处理长度,超出该长度的文本不会进行处理,可以考虑增加
26
+ MAX_NOVEL_SUMMARY_LENGTH=20000
27
+
28
+
29
+ # MongoDB Configuration - MongoDB数据库配置
30
+ # 安装了MongoDB才需要配置,否则不用改动
31
+ # 是否启用MongoDB,启用后下面配置才有效
32
+ ENABLE_MONGODB=false
33
+ # MongoDB连接地址,使用host.docker.internal访问宿主机MongoDB
34
+ MONGODB_URI=mongodb://host.docker.internal:27017/
35
+ # MongoDB数据库名称
36
+ MONGODB_DB_NAME=llm_api
37
+ # 是否启用API缓存
38
+ ENABLE_MONGODB_CACHE=true
39
+ # 缓存命中后重放速度倍率
40
+ CACHE_REPLAY_SPEED=2
41
+ # 缓存命中后最大延迟时间(秒)
42
+ CACHE_REPLAY_MAX_DELAY=5
43
+
44
+
45
+ # API Cost Limits - API费用限制设置,需要依赖于MongoDB
46
+ # 每小时费用上限(人民币)
47
+ API_HOURLY_LIMIT_RMB=100
48
+ # 每天费用上限(人民币)
49
+ API_DAILY_LIMIT_RMB=500
50
+ # 美元兑人民币汇率
51
+ API_USD_TO_RMB_RATE=7
52
+
53
+
54
+ # Wenxin API Settings - 文心API配置
55
+ # 文心API的AK,获取地址:https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application
56
+ WENXIN_AK=
57
+ WENXIN_SK=
58
+ WENXIN_AVAILABLE_MODELS=ERNIE-Novel-8K,ERNIE-4.0-8K,ERNIE-3.5-8K
59
+
60
+ # Doubao API Settings - 豆包API配置
61
+ # DOUBAO_ENDPOINT_IDS和DOUBAO_AVAILABLE_MODELS一一对应,有几个模型就对应几个endpoint_id,这是豆包强制要求的
62
+ # 你可以自行设置DOUBAO_AVAILABLE_MODELS,不一定非要采用下面的
63
+ DOUBAO_API_KEY=
64
+ DOUBAO_ENDPOINT_IDS=
65
+ DOUBAO_AVAILABLE_MODELS=doubao-pro-32k,doubao-lite-32k
66
+
67
+ # GPT API Settings - GPT API配置
68
+
69
+ GPT_AVAILABLE_MODELS=deepseek-r1,gemini-2.0-pro-exp-02-05,lmsys/claude-3-5-sonnet-20241022,windsurf/claude-3-5-sonnet,claude-3-5-sonnet-20240620,o3-mini,gpt-4-turbo-2024-04-09,gemini-2.0-flash-thinking-exp
70
+
71
+ # Local Model Settings - 本地模型配置
72
+ # 本地模型配置需要把下面的localhost替换为host.docker.internal,把8000替换为你的本地大模型服务端口
73
+ # 把local-key替换为你的本地大模型服务API_KEY,把local-model-1替换为你的本地大模型服务模型名
74
+ # 并且docker启动方式有变化,详细参考readme
75
+ LOCAL_BASE_URL=http://localhost:8000/v1
76
+ LOCAL_API_KEY=local-key
77
+ LOCAL_AVAILABLE_MODELS=local-model-1
78
+
79
+ # Zhipuai API Settings - 智谱AI配置
80
+ ZHIPUAI_API_KEY=
81
+ ZHIPUAI_AVAILABLE_MODELS=glm-4-air,glm-4-flashx
82
+
83
+ # Default Model Settings - 默认模型设置
84
+ # 例如:wenxin/ERNIE-Novel-8K, doubao/doubao-pro-32k, gpt/gpt-4o-mini, local/local-model-1
85
+ DEFAULT_MAIN_MODEL=gpt/claude-3-5-sonnet-20240620
86
+ DEFAULT_SUB_MODEL=gpt/gemini-2.0-pro-exp-02-05
Dockerfile ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.12-slim
2
+
3
+ WORKDIR /app
4
+
5
+ COPY . .
6
+ RUN pip install -r requirements.txt
7
+
8
+ CMD ["python app.py"]
README.md CHANGED
@@ -1,10 +1,11 @@
1
  ---
2
- title: Long
3
- emoji: 💻
4
- colorFrom: blue
5
- colorTo: gray
6
  sdk: docker
7
  pinned: false
 
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: long
3
+ emoji: 👀
4
+ colorFrom: purple
5
+ colorTo: yellow
6
  sdk: docker
7
  pinned: false
8
+ license: gpl-3.0
9
+ app_port: 7869
10
  ---
11
 
 
app.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import time

from flask import Flask, request, Response, jsonify
from flask_cors import CORS

# Flask app with CORS enabled for all origins (frontend is served separately).
app = Flask(__name__)
CORS(app)

import sys
import os
# Make both this directory and its parent importable so sibling modules
# (core/, prompts/, llm_api/) resolve regardless of the working directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from prompts.baseprompt import clean_txt_content, load_prompt

from core.writer_utils import KeyPointMsg
from core.draft_writer import DraftWriter
from core.plot_writer import PlotWriter
from core.outline_writer import OutlineWriter

from setting import setting_bp
from summary import process_novel
from backend_utils import get_model_config_from_provider_model
from config import MAX_NOVEL_SUMMARY_LENGTH, MAX_THREAD_NUM, ENABLE_ONLINE_DEMO


app.register_blueprint(setting_bp)

# Server bind configuration (overridable via environment, see .env).
BACKEND_HOST = os.environ.get('BACKEND_HOST', '0.0.0.0')
BACKEND_PORT = int(os.environ.get('BACKEND_PORT', 7869))
32
+
33
+
34
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: report service health plus the current unix timestamp."""
    payload = {
        'status': 'healthy',
        'timestamp': int(time.time()),
    }
    return jsonify(payload), 200
40
+
41
+
42
def load_novel_writer(writer_mode, chunk_list, global_context, x_chunk_length, y_chunk_length, main_model, sub_model, max_thread_num) -> DraftWriter:
    """Instantiate the writer for *writer_mode* ('draft' | 'outline' | 'plot').

    Each mode wraps *global_context* under the key its writer class expects
    ('draft' takes no context). Raises ValueError for an unknown mode.
    """
    mode_table = {
        'draft': (DraftWriter, {}),
        'outline': (OutlineWriter, {'summary': global_context}),
        'plot': (PlotWriter, {'chapter': global_context}),
    }
    if writer_mode not in mode_table:
        raise ValueError(f"unknown writer: {writer_mode}")
    writer_cls, context = mode_table[writer_mode]
    return writer_cls(
        xy_pairs=chunk_list,
        model=get_model_config_from_provider_model(main_model),
        sub_model=get_model_config_from_provider_model(sub_model),
        x_chunk_length=x_chunk_length,
        y_chunk_length=y_chunk_length,
        max_thread_num=max_thread_num,
        global_context=context,
    )
66
+
67
+
68
+
69
+
70
+
71
# Per-mode prompt file names (new / expand / polish for each writer mode).
prompt_names = dict(
    outline = ['新建章节', '扩写章节', '润色章节'],
    plot = ['新建剧情', '扩写剧情', '润色剧情'],
    draft = ['新建正文', '扩写正文', '润色正文'],
)

# Directory holding the prompt files for each writer mode.
prompt_dirname = dict(
    outline = 'prompts/创作章节',
    plot = 'prompts/创作剧情',
    draft = 'prompts/创作正文',
)


# Preload every prompt template at import time so /prompts serves from memory.
PROMPTS = {}
for type_name, dirname in prompt_dirname.items():
    PROMPTS[type_name] = {'prompt_names': prompt_names[type_name]}
    for name in prompt_names[type_name]:
        content = clean_txt_content(load_prompt(dirname, name))
        # Strip a leading "user:\n" role marker so the UI shows only the body.
        if content.startswith("user:\n"):
            content = content[len("user:\n"):]
        PROMPTS[type_name][name] = {'content': content}
92
+
93
+
94
@app.route('/prompts', methods=['GET'])
def get_prompts():
    """Return the preloaded prompt templates, keyed by writer mode."""
    return jsonify(PROMPTS)
97
+
98
def get_delta_chunks(prev_chunks, curr_chunks):
    """Compute the incremental update between two chunk snapshots.

    Both arguments are lists of rows, each row a list of strings. When every
    current string is an extension of its previous counterpart (same shape,
    prefix-preserving), return ("delta", suffixes); otherwise fall back to
    ("init", curr_chunks) so the client replaces its state wholesale.
    """
    if not prev_chunks or len(prev_chunks) != len(curr_chunks):
        return "init", curr_chunks

    rows = list(zip(prev_chunks, curr_chunks))

    # Shape must match row-by-row for a delta to be expressible.
    if any(len(p_row) != len(c_row) for p_row, c_row in rows):
        return "init", curr_chunks

    # Every current string must extend its previous string.
    if any(not c.startswith(p)
           for p_row, c_row in rows
           for p, c in zip(p_row, c_row)):
        return "init", curr_chunks

    delta_chunks = [
        [c[len(p):] for p, c in zip(p_row, c_row)]
        for p_row, c_row in rows
    ]
    return "delta", delta_chunks
129
+
130
+
131
def call_write(writer_mode, chunk_list, global_context, chunk_span, prompt_content, x_chunk_length, y_chunk_length, main_model, sub_model, max_thread_num, only_prompt):
    """Generator driving one writing pass; yields progress dicts for SSE.

    Yields dicts of shape {done, chunk_type, chunk_list, msg} (throttled to at
    most one every 0.2s), or {'prompts': [...]} once when *only_prompt* is set.
    """
    if ENABLE_ONLINE_DEMO:
        if max_thread_num > MAX_THREAD_NUM:
            raise Exception("在线Demo模型下,最大线程数不能超过" + str(MAX_THREAD_NUM) + "!")

    # Re-append the newline each page chunk lost in transit, except on the
    # last row (rows came from individual UI chunks).
    chunk_list = [[e.strip() + ('\n' if e.strip() and rowi != len(chunk_list)-1 else '') for e in row] for rowi, row in enumerate(chunk_list)]

    prev_chunks = None
    def delta_wrapper(chunk_list, done=False, msg=None):
        # Strip newlines again before sending back to the frontend.
        chunk_list = [[e.strip() for e in row] for row in chunk_list]

        nonlocal prev_chunks
        if prev_chunks is None:
            # First payload is always a full snapshot.
            prev_chunks = chunk_list
            return {
                "done": done,
                "chunk_type": "init",
                "chunk_list": chunk_list,
                "msg": msg
            }
        else:
            # Afterwards send suffix deltas when possible (see get_delta_chunks).
            chunk_type, new_chunks = get_delta_chunks(prev_chunks, chunk_list)
            prev_chunks = chunk_list
            return {
                "done": done,
                "chunk_type": chunk_type,
                "chunk_list": new_chunks,
                "msg": msg
            }

    novel_writer = load_novel_writer(writer_mode, chunk_list, global_context, x_chunk_length, y_chunk_length, main_model, sub_model, max_thread_num)

    # Draft mode needs an x->y mapping, so re-partition the target chunk first.
    if writer_mode == 'draft':
        target_chunk = novel_writer.get_chunk(pair_span=chunk_span)
        new_target_chunk = novel_writer.map_text_wo_llm(target_chunk)
        novel_writer.apply_chunks([target_chunk], [new_target_chunk])
        chunk_span = novel_writer.get_chunk_pair_span(new_target_chunk)

    # Snapshot of the pre-write state, used at the end to compute the diff.
    init_novel_writer = load_novel_writer(writer_mode, list(novel_writer.xy_pairs), global_context, x_chunk_length, y_chunk_length, main_model, sub_model, max_thread_num)

    # TODO: writer.write should handle both empty and non-empty y for any
    # prompt — i.e. "new draft" is expressible through "expand draft" too.
    generator = novel_writer.write(prompt_content, pair_span=chunk_span)

    prompt_outputs = []
    last_yield_time = time.time()  # Initialize the last yield time

    prompt_name = ''
    for kp_msg in generator:
        if isinstance(kp_msg, KeyPointMsg):
            # Key-point checkpointing would require computing an edit diff and
            # yielding the writer here; currently only the name is tracked.
            prompt_name = kp_msg.prompt_name
            continue
        else:
            chunk_list = kp_msg

        current_cost = 0
        currency_symbol = ''
        current_model = ''
        data_chunks = []
        prompt_outputs.clear()
        for e in chunk_list:
            if e is None: continue  # None means this chunk is not processed yet
            output, chunk = e
            if output is None: continue  # output None: returned without yielding, so no LLM call was made
            prompt_outputs.append(output)
            current_text = ""
            current_model = output['response_msgs'].model
            current_cost += output['response_msgs'].cost
            currency_symbol = output['response_msgs'].currency_symbol
            if 'plot2text' in output:
                current_text += f"正在建立映射关系..." + '\n'
            else:
                current_text = output['text']
            data_chunks.append((chunk.x_chunk, chunk.y_chunk, current_text))

        if only_prompt:
            # Caller only wants the prompts, not the generation itself.
            yield {'prompts': [e['response_msgs'] for e in prompt_outputs]}
            return

        current_time = time.time()
        if current_time - last_yield_time >= 0.2:  # Throttle: at most one SSE event per 0.2s
            yield delta_wrapper(data_chunks, done=False, msg=f"正在 {prompt_name} ({len(prompt_outputs)} / {len(chunk_list)})" + f" 模型:{current_model} 花费:{current_cost:.5f}{currency_symbol}" if current_model else '')
            last_yield_time = current_time  # Update the last yield time

    # Final payload: an edit-level diff for display purposes; diff support will
    # move out of the writer later since it exists only for the frontend.
    data_chunks = init_novel_writer.diff_to(novel_writer, pair_span=chunk_span)

    yield delta_wrapper(data_chunks, done=True, msg='创作完成!')
224
+
225
+
226
@app.route('/write', methods=['POST'])
def write():
    """SSE endpoint driving one writing pass.

    Streams `data: {...}` events: first the stream_id (for /stop_stream
    cancellation), then incremental chunk updates from call_write, finally a
    done-marked payload (or an error payload rendered into the chunks).
    """
    data = request.json
    writer_mode = data['writer_mode']
    chunk_list = data['chunk_list']
    chunk_span = data['chunk_span']
    prompt_content = data['prompt_content']
    x_chunk_length = data['x_chunk_length']
    y_chunk_length = data['y_chunk_length']
    main_model = data['main_model']
    sub_model = data['sub_model']
    global_context = data['global_context']
    only_prompt = data['only_prompt']

    # BUGFIX: the original only bound max_thread_num when 'settings' was in
    # the payload, raising NameError inside generate() otherwise. Default to
    # the configured limit.
    max_thread_num = MAX_THREAD_NUM
    if 'settings' in data:
        max_thread_num = data['settings']['MAX_THREAD_NUM']

    # Generate unique stream ID so the client can cancel via /stop_stream.
    stream_id = str(time.time())
    active_streams[stream_id] = True

    def generate():
        try:
            # Send stream ID to client first.
            yield f"data: {json.dumps({'stream_id': stream_id})}\n\n"

            for result in call_write(writer_mode, list(chunk_list), global_context, chunk_span, prompt_content, x_chunk_length, y_chunk_length, main_model, sub_model, max_thread_num, only_prompt):
                if not active_streams.get(stream_id, False):
                    # Stream was stopped by client
                    print(f"Stream was stopped by client: {stream_id}")
                    return

                yield f"data: {json.dumps(result)}\n\n"
        except Exception as e:
            # Render the error inside the affected chunks so the UI shows it
            # in place of the generated text.
            error_msg = f"创作出错:\n{str(e)}"
            error_chunk_list = [[*e[:2], error_msg] for e in chunk_list[chunk_span[0]:chunk_span[1]]]

            error_data = {
                "done": True,
                "chunk_type": "init",
                "chunk_list": error_chunk_list
            }
            yield f"data: {json.dumps(error_data)}\n\n"
        finally:
            # Clean up stream tracking
            if stream_id in active_streams:
                del active_streams[stream_id]

    return Response(generate(), mimetype='text/event-stream')
276
+
277
+
278
@app.route('/summary', methods=['POST'])
def process_novel_text():
    """SSE endpoint summarizing an imported novel via process_novel().

    Streams the stream_id first, then throttled progress payloads; errors are
    reported as a {'progress_msg': ...} event.
    """
    data = request.json
    content = data['content']
    novel_name = data['novel_name']

    # Generate unique stream ID (timestamp string) for cancellation tracking.
    stream_id = str(time.time())
    active_streams[stream_id] = True

    def generate():
        try:
            yield f"data: {json.dumps({'stream_id': stream_id})}\n\n"

            main_model = get_model_config_from_provider_model(data['main_model'])
            sub_model = get_model_config_from_provider_model(data['sub_model'])
            max_novel_summary_length = data['settings']['MAX_NOVEL_SUMMARY_LENGTH']
            max_thread_num = data['settings']['MAX_THREAD_NUM']
            last_yield_time = 0
            for result in process_novel(content, novel_name, main_model, sub_model, max_novel_summary_length, max_thread_num):
                if not active_streams.get(stream_id, False):
                    # Stream was stopped by client
                    print(f"Stream was stopped by client: {stream_id}")
                    return

                current_time = time.time()
                yield_value = f"data: {json.dumps(result)}\n\n"
                # Throttle: forward at most one event per 0.2s.
                if current_time - last_yield_time >= 0.2:
                    last_yield_time = current_time
                    yield yield_value
            # NOTE(review): current_time/yield_value are unbound if the loop
            # above never ran — presumably process_novel always yields at
            # least once; confirm.
            if current_time - last_yield_time < 0.2:
                # HACK: debug leftover — dumps the final payload to tmp.yaml
                # on every request; consider removing in production.
                import yaml
                result_dict = json.loads(yield_value.replace('data: ', '').strip())
                with open('tmp.yaml', 'w', encoding='utf-8') as f:
                    yaml.dump(result_dict, f, allow_unicode=True)

                yield yield_value  # Ensure last yield is returned

        except Exception as e:
            error_data = {
                "progress_msg": f"处理出错:{str(e)}",
            }
            yield f"data: {json.dumps(error_data)}\n\n"
        finally:
            # Clean up stream tracking
            if stream_id in active_streams:
                del active_streams[stream_id]

    return Response(generate(), mimetype='text/event-stream')
328
+
329
# Maps stream_id -> bool; set False by /stop_stream to cancel a running SSE
# generator. Defined after the routes that reference it, which is safe because
# handlers only run after the module is fully imported.
active_streams = {}
331
+
332
@app.route('/stop_stream', methods=['POST'])
def stop_stream():
    """Mark a running SSE stream as cancelled; its generator exits on next check."""
    stream_id = request.json.get('stream_id')
    if stream_id in active_streams:
        active_streams[stream_id] = False
    return jsonify({'success': True})
339
+
340
if __name__ == '__main__':
    # Dev entry point; presumably served by gunicorn in production
    # (see WORKERS/THREADS in .env) — confirm deployment setup.
    app.run(host=BACKEND_HOST, port=BACKEND_PORT, debug=False)
backend_utils.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from llm_api import ModelConfig
2
+
3
+ def get_model_config_from_provider_model(provider_model):
4
+ from config import API_SETTINGS
5
+ provider, model = provider_model.split('/', 1)
6
+ provider_config = API_SETTINGS[provider]
7
+
8
+ if provider == 'doubao':
9
+ # Get the index of the model in available_models to find corresponding endpoint_id
10
+ model_index = provider_config['available_models'].index(model)
11
+ endpoint_id = provider_config['endpoint_ids'][model_index] if model_index < len(provider_config['endpoint_ids']) else ''
12
+ model_config = {**provider_config, 'model': model, 'endpoint_id': endpoint_id}
13
+ else:
14
+ model_config = {**provider_config, 'model': model}
15
+
16
+ # Remove lists from config before creating ModelConfig
17
+ if 'available_models' in model_config:
18
+ del model_config['available_models']
19
+ if 'endpoint_ids' in model_config:
20
+ del model_config['endpoint_ids']
21
+
22
+ return ModelConfig(**model_config)
config.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import dotenv_values, load_dotenv
3
+
4
+ print("Loading .env file...")
5
+ env_path = os.path.join(os.path.dirname(__file__), '.env')
6
+ if os.path.exists(env_path):
7
+ env_dict = dotenv_values(env_path)
8
+
9
+ print("Environment variables to be loaded:")
10
+ for key, value in env_dict.items():
11
+ print(f"{key}={value}")
12
+ print("-" * 50)
13
+
14
+ os.environ.update(env_dict)
15
+ print(f"Loaded environment variables from: {env_path}")
16
+ else:
17
+ print("Warning: .env file not found")
18
+
19
+
20
+ # Thread Configuration
21
+ MAX_THREAD_NUM = int(os.getenv('MAX_THREAD_NUM', 5))
22
+
23
+
24
+ MAX_NOVEL_SUMMARY_LENGTH = int(os.getenv('MAX_NOVEL_SUMMARY_LENGTH', 20000))
25
+
26
+ # MongoDB Configuration
27
+ ENABLE_MONOGODB = os.getenv('ENABLE_MONGODB', 'false').lower() == 'true'
28
+ MONGODB_URI = os.getenv('MONGODB_URI', 'mongodb://127.0.0.1:27017/')
29
+ MONOGODB_DB_NAME = os.getenv('MONGODB_DB_NAME', 'llm_api')
30
+ ENABLE_MONOGODB_CACHE = os.getenv('ENABLE_MONGODB_CACHE', 'true').lower() == 'true'
31
+ CACHE_REPLAY_SPEED = float(os.getenv('CACHE_REPLAY_SPEED', 2))
32
+ CACHE_REPLAY_MAX_DELAY = float(os.getenv('CACHE_REPLAY_MAX_DELAY', 5))
33
+
34
+ # API Cost Limits
35
+ API_COST_LIMITS = {
36
+ 'HOURLY_LIMIT_RMB': float(os.getenv('API_HOURLY_LIMIT_RMB', 100)),
37
+ 'DAILY_LIMIT_RMB': float(os.getenv('API_DAILY_LIMIT_RMB', 500)),
38
+ 'USD_TO_RMB_RATE': float(os.getenv('API_USD_TO_RMB_RATE', 7))
39
+ }
40
+
41
+ # API Settings
42
+ API_SETTINGS = {
43
+ 'wenxin': {
44
+ 'ak': os.getenv('WENXIN_AK', ''),
45
+ 'sk': os.getenv('WENXIN_SK', ''),
46
+ 'available_models': os.getenv('WENXIN_AVAILABLE_MODELS', '').split(','),
47
+ 'max_tokens': 4096,
48
+ },
49
+ 'doubao': {
50
+ 'api_key': os.getenv('DOUBAO_API_KEY', ''),
51
+ 'endpoint_ids': os.getenv('DOUBAO_ENDPOINT_IDS', '').split(','),
52
+ 'available_models': os.getenv('DOUBAO_AVAILABLE_MODELS', '').split(','),
53
+ 'max_tokens': 4096,
54
+ },
55
+ 'gpt': {
56
+ 'base_url': os.getenv('GPT_BASE_URL', ''),
57
+ 'api_key': os.getenv('GPT_API_KEY', ''),
58
+ 'proxies': os.getenv('GPT_PROXIES', ''),
59
+ 'available_models': os.getenv('GPT_AVAILABLE_MODELS', '').split(','),
60
+ 'max_tokens': 4096,
61
+ },
62
+ 'zhipuai': {
63
+ 'api_key': os.getenv('ZHIPUAI_API_KEY', ''),
64
+ 'available_models': os.getenv('ZHIPUAI_AVAILABLE_MODELS', '').split(','),
65
+ 'max_tokens': 4096,
66
+ },
67
+ 'local': {
68
+ 'base_url': os.getenv('LOCAL_BASE_URL', ''),
69
+ 'api_key': os.getenv('LOCAL_API_KEY', ''),
70
+ 'available_models': os.getenv('LOCAL_AVAILABLE_MODELS', '').split(','),
71
+ 'max_tokens': 4096,
72
+ }
73
+ }
74
+
75
+ for model in API_SETTINGS.values():
76
+ model['available_models'] = [e.strip() for e in model['available_models']]
77
+
78
+ DEFAULT_MAIN_MODEL = os.getenv('DEFAULT_MAIN_MODEL', 'wenxin/ERNIE-Novel-8K')
79
+ DEFAULT_SUB_MODEL = os.getenv('DEFAULT_SUB_MODEL', 'wenxin/ERNIE-3.5-8K')
80
+
81
+ ENABLE_ONLINE_DEMO = os.getenv('ENABLE_ONLINE_DEMO', 'false').lower() == 'true'
core/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # core 模块为LongNovelGPT到2.0版本之间的过渡,在core模块中进行一些新功能和设计的尝试
core/backend.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import importlib
3
+ from core.draft_writer import DraftWriter
4
+ from core.plot_writer import PlotWriter
5
+ from core.outline_writer import OutlineWriter
6
+ from core.writer_utils import KeyPointMsg
7
+ from core.diff_utils import match_span_by_char
8
+ import copy
9
+ import types
10
+
11
def load_novel_writer(writer, setting) -> DraftWriter:
    """Rebuild the concrete writer instance for writer['current_w'] state."""
    w_name = writer['current_w']
    state = writer[w_name]

    cls_by_name = {
        'draft_w': DraftWriter,
        'outline_w': OutlineWriter,
        'chapters_w': PlotWriter,
        'plot_w': PlotWriter,
    }
    if w_name not in cls_by_name:
        raise ValueError(f"unknown writer: {w_name}")

    return cls_by_name[w_name](
        xy_pairs=list(state.get('xy_pairs', [['', '']])),
        model=setting['model'],
        sub_model=setting['sub_model'],
        x_chunk_length=state['x_chunk_length'],
        y_chunk_length=state['y_chunk_length'],
    )
35
+
36
def dump_novel_writer(writer, novel_writer, apply_chunks=None, cost=0, currency_symbol='¥'):
    """Serialize *novel_writer* state back into a deep-copied writer dict.

    Args:
        writer: frontend writer-state dict; never mutated (a deep copy is
            returned — design-wise dump should not touch the input, though
            the copy may be costly here).
        novel_writer: writer instance whose xy_pairs are persisted.
        apply_chunks: pending chunk edits to store. BUGFIX: was a mutable
            default `{}`; now None-sentinel (same observable behavior).
        cost: cost of this step's LLM calls.
        currency_symbol: currency of *cost*.

    Returns:
        A new writer-state dict with the current sub-writer's fields updated.
    """
    if apply_chunks is None:
        apply_chunks = {}

    new_writer = copy.deepcopy(writer)

    current_w_name = new_writer['current_w']
    current_w = new_writer[current_w_name]

    current_w['xy_pairs'] = list(novel_writer.xy_pairs)

    current_w['current_cost'] = cost
    current_w['currency_symbol'] = currency_symbol
    # NOTE: total_cost accumulation was deliberately disabled upstream:
    # current_w['total_cost'] += current_w['current_cost']

    current_w['apply_chunks'] = apply_chunks

    return new_writer
54
+
55
def call_write_long_novel(writer, setting):
    """Drive the full outline -> plot -> draft pipeline as a generator.

    Yields intermediate writer states for progress display; the return value
    is the final writer state. The pipeline is a list of ops whose 'eval'
    strings are executed with eval()/exec().
    """
    # SECURITY NOTE(review): exec()/eval() on op strings is safe only while
    # ops are built here; never populate progress['ops'] from user input.
    writer = copy.deepcopy(writer)
    progress = writer['progress']

    # NOTE: `or True` forces the op list to be rebuilt on every call; only
    # cur_op_i survives from a previous progress dict.
    if not progress or True:
        progress = dict(
            cur_op_i = progress['cur_op_i'] if progress else 0,
            ops = [
                {
                    'before_eval': 'writer["current_w"] = "outline_w"',
                    'eval': 'call_write(writer, setting, False, "构思全书的大致剧情,并将其以一个故事的形式写下来,只写大致情节。")',
                    'title': '创作大纲',
                    'subtitle': '生成大纲'
                },
                {
                    'eval': 'call_accept(writer, setting)',
                },
                {
                    'eval': 'call_write(writer, setting, True, "对整个情节进行重写,使其更加有故事性。")',
                    'title': '创作大纲',
                    'subtitle': '润色大纲',
                },
                {
                    'eval': 'call_accept(writer, setting)',
                },
                # Plot-writing stage
                {
                    'before_eval': 'init_chapters_w(writer)',
                    'eval': 'call_write(writer, setting, False, "丰富其中的剧情细节。")',
                    'title': '创作剧情',
                    'subtitle': '生成剧情'
                },
                {
                    'eval': 'call_accept(writer, setting)',
                },
                {
                    'eval': 'call_write(writer, setting, True, "对情节进行重写,使其有更多的剧情细节,同时更加有具有故事性。")',
                    'title': '创作剧情',
                    'subtitle': '扩充剧情',
                },
                {
                    'eval': 'call_accept(writer, setting)',
                },
                # Draft-writing stage
                {
                    'before_eval': 'init_draft_w(writer)',
                    'eval': 'call_write(writer, setting, False, "创作的是正文,而不是剧情,需要像一个小说家那样去描写这个故事。")',
                    'title': '创作正文',
                    'subtitle': '生成正文'
                },
                {
                    'eval': 'call_accept(writer, setting)',
                },
                {
                    'eval': 'call_write(writer, setting, True, "润色正文")',
                    'title': '创作正文',
                    'subtitle': '润色正文'
                },
                {
                    'eval': 'call_accept(writer, setting)',
                }
            ]
        )

        # TODO: consider providing context already at init_plot time, similar
        # to rewrite_plot.

        # Ops without an explicit title inherit the most recent title/subtitle.
        title, subtitle = '', ''
        for op in progress['ops']:
            if 'title' not in op:
                op['title'], op['subtitle'] = title, subtitle
            else:
                title, subtitle = op['title'], op['subtitle']


    writer['progress'] = progress
    yield writer

    while progress['cur_op_i'] < len(progress['ops']):
        current_op = progress['ops'][progress['cur_op_i']]
        if 'before_eval' in current_op:
            exec(current_op['before_eval'])
        # Delegate to the op's generator; its return value is the new writer.
        writer = yield from eval(current_op['eval'])
        progress = writer['progress']

        progress['cur_op_i'] += 1
        # Advancing cur_op_i marks this yield as a "stable" writer_state.
        yield writer

    return writer
143
+
144
def match_quote_text(writer, setting, quote_text):
    """Locate *quote_text* inside the current writer's y text.

    Returns (aligned_span, matched_text) when the char-level match ratio
    exceeds 0.5, otherwise (None, '').
    """
    novel_writer = load_novel_writer(writer, setting)
    full_text = novel_writer.y
    span, ratio = match_span_by_char(full_text, quote_text)
    if ratio <= 0.5:
        return None, ''
    aligned_span, _ = novel_writer.align_span(y_span=span)
    return aligned_span, full_text[aligned_span[0]:aligned_span[1]]
153
+
154
# Backend entry point: receives a COPY of the frontend writer_state.
# Yielded values are generally used by the frontend to show progress; only
# the return value is considered for updating writer_state.
def call_write(writer, setting, auto_write=False, suggestion=None):
    """Run one write (or auto-write) pass; yields progress writer states."""
    novel_writer = load_novel_writer(writer, setting)

    current_w = writer[writer['current_w']]
    current_w['xy_pairs'] = list(novel_writer.xy_pairs)

    quote_span = writer['quote_span']

    if auto_write:
        assert quote_span is None, "auto_write模式下,不能有quote_text"
        generator = novel_writer.auto_write()
    else:
        # TODO: writer.write should handle both empty and non-empty y for any
        # prompt — "new draft" should be expressible through "expand draft".
        generator = novel_writer.write(suggestion, y_span=quote_span)

    prompt_outputs = []
    for kp_msg in generator:
        if isinstance(kp_msg, KeyPointMsg):
            # Key-point checkpointing would need an edit diff computed here.
            yield kp_msg
            continue
        else:
            chunk_list = kp_msg

        current_cost = 0
        apply_chunks = []
        prompt_outputs.clear()
        for output, chunk in chunk_list:
            prompt_outputs.append(output)
            current_text = ""
            current_cost += output['response_msgs'].cost
            currency_symbol = output['response_msgs'].currency_symbol
            cost_info = f"\n(预计花费:{output['response_msgs'].cost:.4f}{output['response_msgs'].currency_symbol})"
            if 'plot2text' in output:
                current_text += f"正在建立映射关系..." + cost_info + '\n'
            else:
                current_text += output['text'] + cost_info + '\n'
            apply_chunks.append((chunk, 'y_chunk', current_text))

        # NOTE(review): currency_symbol is unbound here if chunk_list is
        # empty — presumably the writer always yields non-empty chunk lists;
        # confirm.
        new_writer = dump_novel_writer(writer, novel_writer, apply_chunks=apply_chunks, cost=current_cost, currency_symbol=currency_symbol)
        new_writer['prompt_outputs'] = prompt_outputs
        yield new_writer

    # Compute an edit-level diff for frontend display; diff support will move
    # out of the writer later since it exists only for display purposes.
    apply_chunks = []
    for chunk, key, value in load_novel_writer(writer, setting).diff_to(novel_writer):
        apply_chunks.append((chunk, key, value))
    writer[writer['current_w']]['apply_chunks'] = apply_chunks
    writer['prompt_outputs'] = prompt_outputs
    return writer
208
+
209
def call_accept(writer, setting):
    """Apply the pending apply_chunks to the writer and persist the result."""
    state = writer[writer['current_w']]

    novel_writer = load_novel_writer(writer, setting)
    for chunk, key, text in state['apply_chunks']:
        novel_writer.apply_chunk(chunk, key, text)

    return dump_novel_writer(writer, novel_writer)
core/diff_utils.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import difflib
2
+ from difflib import SequenceMatcher
3
+
4
+
5
def match_span_by_char(text, chunk):
    """Char-level fuzzy search of *chunk* inside *text* via difflib.

    Returns ((start, end), ratio) where the span covers from the first to the
    last matching region in *text*, and ratio is the fraction of *chunk*
    characters matched. Returns (None, 0) when nothing matches.
    """
    matcher = difflib.SequenceMatcher(None, text, chunk)

    # Collect the text-side spans of all 'equal' opcodes.
    equal_spans = [
        (i1, i2)
        for tag, i1, i2, _, _ in matcher.get_opcodes()
        if tag == 'equal'
    ]

    if not equal_spans:
        return None, 0

    match_span = (equal_spans[0][0], equal_spans[-1][1])
    matched_chars = sum(end - start for start, end in equal_spans)
    return match_span, matched_chars / len(chunk)
24
+
25
def match_sequences(a_list, b_list):
    """
    Greedily match two string lists, returning aligned index-range pairs.

    Args:
        a_list: first list of strings
        b_list: second list of strings

    Returns:
        list[((l, r), (j, k))]: matched index pairs, where (l, r) is a
        start/end range into a_list and (j, k) a range into b_list.
    """
    # NOTE(review): m/n deliberately exclude the last element of each list;
    # everything left when the loop ends (including the last elements) is
    # swept into the single catch-all pair appended at the end.
    m, n = len(a_list) - 1, len(b_list) - 1
    matches = []
    i = j = 0

    while i < m and j < n:
        # Best candidate pairing starting at (i, j).
        best_match = None
        best_ratio = -1  # -1 so even a zero-ratio candidate is accepted

        # Try joining up to 3 consecutive items from each side.
        for l in range(i, min(i + 3, m)):  # limit forward lookahead
            current_a = ''.join(a_list[i:l + 1])

            for r in range(j, min(j + 3, n)):  # limit forward lookahead
                current_b = ''.join(b_list[j:r + 1])

                # Symmetric char-level score via match_span_by_char.
                span1, ratio1 = match_span_by_char(current_b, current_a)
                span2, ratio2 = match_span_by_char(current_a, current_b)
                ratio = ratio1 * ratio2

                if ratio > best_ratio:
                    best_ratio = ratio
                    best_match = ((i, l + 1), (j, r + 1))

        if best_match:
            matches.append(best_match)
            i = best_match[0][1]
            j = best_match[1][1]
        else:
            # No usable match; advance both cursors one step.
            i += 1
            j += 1

    # Catch-all pair covering whatever remains on both sides.
    matches.append(((i, m+1), (j, n+1)))

    return matches
74
+
75
def get_chunk_changes(source_chunk_list, target_chunk_list):
    """Diff two chunk lists and group them into aligned change regions.

    The chunks on each side are joined with a sentinel separator and diffed
    character-wise; each chunk is then labelled by the dominant opcode family
    that touched it, and consecutive same-label runs are paired up.

    Args:
        source_chunk_list: original text split into chunks.
        target_chunk_list: revised text split into chunks.

    Returns:
        list[(start_i, end_i, start_j, end_j)]: half-open chunk-index ranges,
        meaning source[start_i:end_i] was rewritten as target[start_j:end_j].
    """
    # Sentinel assumed not to occur inside chunk text — TODO confirm.
    SEPARATOR = "%|%"
    source_text = SEPARATOR.join(source_chunk_list)
    target_text = SEPARATOR.join(target_chunk_list)

    # Per-chunk character tallies for the two opcode families.
    source_chunk_stats = [{'delete_or_insert': 0, 'replace_or_equal': 0} for _ in source_chunk_list]
    target_chunk_stats = [{'delete_or_insert': 0, 'replace_or_equal': 0} for _ in target_chunk_list]

    # Start offsets of each chunk inside the joined texts (plus a final
    # sentinel offset equal to the total length).
    source_positions = [0]
    target_positions = [0]
    pos = 0
    for chunk in source_chunk_list[:-1]:
        pos += len(chunk) + len(SEPARATOR)
        source_positions.append(pos)
    source_positions.append(len(source_text))

    pos = 0
    for chunk in target_chunk_list[:-1]:
        pos += len(chunk) + len(SEPARATOR)
        target_positions.append(pos)
    target_positions.append(len(target_text))

    def update_chunk_stats(positions, stats, start, end, tag):
        # Distribute the [start, end) character span of one opcode across the
        # chunks it overlaps.
        for i in range(len(positions) - 1):
            chunk_start = positions[i]
            chunk_end = positions[i + 1]

            overlap_start = max(chunk_start, start)
            overlap_end = min(chunk_end, end)

            if overlap_end > overlap_start:
                stats[i][tag] += overlap_end - overlap_start

    matcher = SequenceMatcher(None, source_text, target_text)

    # Tally every opcode into the per-chunk stats.
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == 'replace' or tag == 'equal':
            update_chunk_stats(source_positions, source_chunk_stats, i1, i2, 'replace_or_equal')
            update_chunk_stats(target_positions, target_chunk_stats, j1, j2, 'replace_or_equal')
        elif tag == 'delete':
            update_chunk_stats(source_positions, source_chunk_stats, i1, i2, 'delete_or_insert')
        elif tag == 'insert':
            update_chunk_stats(target_positions, target_chunk_stats, j1, j2, 'delete_or_insert')

    # A chunk's final label is whichever family covered more of its characters
    # (ties go to 'replace_or_equal').
    def get_final_tag(stats):
        return 'delete_or_insert' if stats['delete_or_insert'] > stats['replace_or_equal'] else 'replace_or_equal'

    source_chunk_tags = [get_final_tag(stats) for stats in source_chunk_stats]
    target_chunk_tags = [get_final_tag(stats) for stats in target_chunk_stats]

    # Two-pointer sweep: pair runs of same-labelled chunks into change regions.
    changes = []
    i = j = 0  # i walks source_chunk_list, j walks target_chunk_list
    start_i = start_j = 0
    m, n = len(source_chunk_list), len(target_chunk_list)
    while i < m or j < n:
        if i < m and source_chunk_tags[i] == 'delete_or_insert':
            while i < m and source_chunk_tags[i] == 'delete_or_insert': i += 1
        elif j < n and target_chunk_tags[j] == 'delete_or_insert':
            while j < n and target_chunk_tags[j] == 'delete_or_insert': j += 1
        elif i < m and j < n and source_chunk_tags[i] == 'replace_or_equal' and target_chunk_tags[j] == 'replace_or_equal':
            while i < m and j < n and source_chunk_tags[i] == 'replace_or_equal' and target_chunk_tags[j] == 'replace_or_equal':
                i += 1
                j += 1
        else:
            # TODO: known limitation — 'equal' runs are not individually
            # re-aligned, so pairing can drift here; bail out and let the
            # final catch-all range below absorb the rest.
            break

        # Whenever either pointer advanced, close out a change region.
        if (i > start_i or j > start_j):
            changes.append((start_i, i, start_j, j))
            start_i, start_j = i, j

    # Anything left unpaired becomes one final region.
    if (i < m or j < n):
        changes.append((start_i, m, start_j, n))

    return changes
156
+
157
+
158
# Example usage / manual smoke test.
def test_get_chunk_changes():
    """Print the change regions computed for a sample outline revision."""
    source_chunks = ['', '', '', '第3章 初露锋芒\n在高人指导下,萧炎的斗气水平迅速提升,开始在家族中引起注意。\n', '', '第4章 异火初现\n萧炎得知“异火”的存在,决定踏上寻找异火的旅程。\n']
    target_chunks = ['', '第3章 初露锋芒\n在高人指导下,萧炎的斗气水平迅速提升,开始在家族中引起注意。', '第3.5章 家族试炼\n萧炎参加家族举办的试炼,凭借新学的斗技和炼丹术,展现出超凡实力,获得家族长老的关注和认可。', '第4章 异火初现\n萧炎得知“异火”的存在,决定踏上寻找异火的旅程。']

    changes = get_chunk_changes(source_chunks, target_chunks)

    # First pass: index ranges only.
    for src_lo, src_hi, tgt_lo, tgt_hi in changes:
        print(f"Source chunks {src_lo}:{src_hi} -> Target chunks {tgt_lo}:{tgt_hi}")

    # Second pass: the actual before/after text of each region.
    for src_lo, src_hi, tgt_lo, tgt_hi in changes:
        print('-' * 20)
        print(f"{''.join(source_chunks[src_lo:src_hi])} -> {''.join(target_chunks[tgt_lo:tgt_hi])}")

if __name__ == "__main__":
    test_get_chunk_changes()
core/draft_writer.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from core.writer_utils import KeyPointMsg
2
+ from core.writer import Writer
3
+
4
+ from prompts.创作正文.prompt import main as prompt_draft
5
+ from prompts.提炼.prompt import main as prompt_summary
6
+
7
+
8
class DraftWriter(Writer):
    """Writer that turns plot (x) chunks into draft prose (y) chunks, and back."""

    def __init__(self, xy_pairs, global_context, model=None, sub_model=None,
                 x_chunk_length=500, y_chunk_length=1000, max_thread_num=5):
        super().__init__(
            xy_pairs,
            global_context,
            model,
            sub_model,
            x_chunk_length=x_chunk_length,
            y_chunk_length=y_chunk_length,
            max_thread_num=max_thread_num,
        )

    def write(self, user_prompt, pair_span=None):
        """Generate draft text for *pair_span*, yielding progress messages."""
        selected = self.get_chunk(pair_span=pair_span)
        # The plot (x side) must exist and be non-trivial before drafting.
        if not selected.x_chunk:
            raise Exception("需要提供剧情。")
        if len(selected.x_chunk) <= 5:
            raise Exception("剧情不能少于5个字。")

        work_chunks = self.get_chunks(pair_span)
        yield from self.batch_write_apply_text(work_chunks, prompt_draft, user_prompt)

    def summary(self, pair_span=None):
        """Condense existing draft text back into plot, yielding progress."""
        selected = self.get_chunk(pair_span=pair_span)
        # The draft (y side) must exist and be non-trivial before summarizing.
        if not selected.y_chunk:
            raise Exception("没有正文需要总结。")
        if len(selected.y_chunk) <= 5:
            raise Exception("需要总结的正文不能少于5个字。")

        # Split into smaller pieces first so that get_chunks can operate.
        resized = self.map_text_wo_llm(selected)
        self.apply_chunks([selected], [resized])
        span = self.get_chunk_pair_span(resized)

        work_chunks = self.get_chunks(span, context_length_ratio=0)
        yield from self.batch_write_apply_text(work_chunks, prompt_summary, "提炼剧情")

    def split_into_chapters(self):
        # Not implemented yet.
        pass

    def get_model(self):
        return self.model

    def get_sub_model(self):
        return self.sub_model
core/frontend.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Gradio frontend for the long-novel writer.
#
# The UI is driven by two gr.State dicts: ``writer_state`` (per-stage writing
# progress, flags, suggestions) and ``setting_state`` (API/model settings).
# Most handlers mutate these dicts in place and rely on ``flash_interface``
# to re-derive every control's interactive/visible state.
#
# NOTE(review): this file was reconstructed from a whitespace-mangled diff;
# indentation at a few marked spots is a best-effort reconstruction.
import re
from rich.traceback import install
install(show_locals=False)

import gradio as gr
import yaml
import functools
import time
import sys
import os
import copy

# Make this package and its parent importable when run as a script.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from config import RENDER_SAVE_LOAD_BTN, RENDER_STOP_BTN
from core.backend import call_write, call_accept, match_quote_text
from core.frontend_copy import enable_copy_js, on_copy
from core.frontend_setting import new_setting, render_setting
from core.frontend_utils import (
    title, info,
    create_progress_md, create_text_md, messages2chatbot,
    init_writer, has_accept, is_running, try_cancel, writer_y_is_empty, writer_x_is_empty,
    cancellable, process_writer_to_backend, process_writer_from_backend,
    init_chapters_w, init_draft_w
)
from core.writer_utils import KeyPointMsg

from prompts.baseprompt import clean_txt_content, load_prompt


# Load the example novel ideas shown under the idea textbox.
with open('prompts/idea-examples.yaml', 'r', encoding='utf-8') as file:
    examples_data = yaml.safe_load(file)

# gr.Examples expects a list of rows; each row has one cell (the idea text).
examples = [[example['idea']] for example in examples_data['examples']]

with gr.Blocks(head=enable_copy_js) as demo:
    gr.HTML(title)
    with gr.Accordion("使用指南"):
        gr.Markdown(info)

    # Central mutable state: writing progress and API settings.
    writer_state = gr.State(init_writer('', check_empty=False))
    setting_state = gr.State(new_setting())

    if RENDER_SAVE_LOAD_BTN:
        with gr.Row():
            save_button = gr.Button("保存状态")
            load_button = gr.Button("加载状态")
            save_file_name = gr.Textbox(value='states', placeholder='输入文件名', lines=1, label=None, show_label=False, container=False)

        def save_states(save_file_name, writer, setting):
            # Persist both state dicts to <name>.json next to the app.
            import json
            json_file_name = save_file_name + '.json'
            with open(json_file_name, 'w', encoding='utf-8') as f:
                json.dump({
                    'writer': writer,
                    'setting': setting
                }, f, ensure_ascii=False, indent=2)
            gr.Info(f"状态已保存到文件:{json_file_name}")

        def load_states(save_file_name):
            # Restore both state dicts from <name>.json.
            import json
            json_file_name = save_file_name + '.json'
            try:
                with open(json_file_name, 'r', encoding='utf-8') as f:
                    states = json.load(f)
                gr.Info(f"状态文件已加载:{json_file_name}")
                # Bump render_time to force the settings panel to re-render:
                # model selection never reassigns setting_state, and the
                # settings UI must hold the very same object as setting_state.
                states['setting']['render_time'] = time.time()
                return states['writer'], states['setting']
            except FileNotFoundError:
                raise gr.Error(f"未找到保存的状态文件:{json_file_name}")

    idea_textbox = gr.Textbox(placeholder='用一段话描述你要写的小说,或者从下方示例中选择一个创意...', lines=2, scale=1, label=None, show_label=False, container=False, max_length=1000)

    gr.Examples(
        label='示例',
        examples=examples,
        inputs=[idea_textbox],
    )

    # Stage buttons: outline -> chapters (plot) -> draft (prose).
    with gr.Row():
        outline_btn = gr.Button("创作大纲", scale=1, min_width=1, interactive=True, variant='primary')
        chapters_btn = gr.Button("创作剧情", scale=1, min_width=1, interactive=False, variant='secondary')
        draft_btn = gr.Button("创作正文", scale=1, min_width=1, interactive=False, variant='secondary')
        auto_checkbox = gr.Checkbox(label='一键生成', scale=1, value=False, visible=False)  # TODO: "auto" mode is incomplete as of V1.10, hidden for now

    progress_md = create_progress_md(writer_state.value)
    text_md = create_text_md(writer_state.value)

    @gr.render(inputs=writer_state)
    def create_prompt_preview(writer):
        # Re-rendered whenever writer_state changes: shows captured prompts
        # while generation is paused, plus a "continue" button.
        prompt_outputs = writer['prompt_outputs'] if 'prompt_outputs' in writer else []
        with gr.Accordion("Prompt预览", open=bool(prompt_outputs)):
            pause_on_prompt_finished_checkbox = gr.Checkbox(label='允许在LLM响应完成后,预览Prompt', scale=1, value=writer['pause_on_prompt_finished_flag'])

            for i, prompt_output in enumerate(prompt_outputs, 1):
                with gr.Tab(f"Prompt {i}"):
                    gr.Chatbot(messages2chatbot(prompt_output['response_msgs']), type='messages')
            if not prompt_outputs:
                gr.Markdown('当前没有可预览的Prompt。')

            continue_btn = gr.Button('继续', visible=bool(prompt_outputs), variant='primary')

            def on_pause_on_prompt_finished(value):
                # Closure over the rendered writer dict; mutated in place.
                if value:
                    gr.Info("在LLM响应完成后,将可以预览Prompt")
                writer['pause_on_prompt_finished_flag'] = value

            pause_on_prompt_finished_checkbox.change(on_pause_on_prompt_finished, [pause_on_prompt_finished_checkbox])

            def on_continue(writer):
                # Clear the pause so _on_write_all's busy-wait loop resumes.
                writer['pause_flag'] = False
                writer['prompt_outputs'] = []
                return writer

            continue_btn.click(on_continue, writer_state, writer_state)


    with gr.Row():
        rewrite_all_button = gr.Button("开始创作", min_width=100, scale=2, variant='secondary', interactive=False)
        suggestion_dropdown = gr.Dropdown(choices=[], min_width=100, scale=2, label=None, show_label=False, container=False, allow_custom_value=False)
        quote_checkbox = gr.Checkbox(label='允许引用', min_width=100, scale=2, value=False)
        gr.Textbox('窗口大小:', container=False, text_align='right', scale=1, min_width=100)
        chunk_length_dropdown = gr.Dropdown(choices=[], min_width=80, scale=1, label=None, show_label=False, container=False, allow_custom_value=False)

    # Shows the currently-quoted text as a blockquote; hidden by default.
    quote_md = gr.Markdown(visible=False)

    def on_quote_checkbox_change(writer, value):
        # Quoting is only meaningful for chapters/draft stages.
        if writer['current_w'] == 'outline_w':
            gr.Info("大纲创作不支持引用\n考虑在剧情和正文创作中使用吧~")
            return gr.update(value=False, visible=False)

        if value:
            gr.Info("允许引用(右键或Ctrl+C复制你想引用的文本)")
        # NOTE(review): reconstructed indentation — the quote reset below is
        # assumed to run on both check and uncheck; confirm against original.
        writer['quote_span'] = None
        writer['quoted_text'] = ''
        return gr.update(value=None, visible=False)

    quote_checkbox.change(on_quote_checkbox_change, [writer_state, quote_checkbox], [quote_md])

    def on_chunk_length_change(writer, value):
        # Store the chosen window size on the currently-active sub-writer.
        current_w_name = writer['current_w']
        writer[current_w_name]['y_chunk_length'] = value
        return gr.update(value=value)

    chunk_length_dropdown.change(on_chunk_length_change, [writer_state, chunk_length_dropdown], [chunk_length_dropdown])

    def on_copy_handle(text, writer, setting, quote_checkbox):
        # Fired via the injected JS whenever the user copies text in the page;
        # tries to register the selection as a quote. Note: the last parameter
        # shadows the quote_checkbox component — here it is its bool value.
        text = text.strip()

        if has_accept(writer):
            gr.Info('考虑先接受或拒绝修改哦~')
            return gr.update(visible=False)

        if len(text) < 10:
            gr.Info('选中的文本太短,无法引用')
            return gr.update(visible=False)

        if quote_checkbox:
            quote_span, quoted_text = match_quote_text(writer, setting, text)
            if quote_span:
                writer['quote_span'] = quote_span
                writer['quoted_text'] = quoted_text
                # Abbreviate long quotes and render as a fenced blockquote.
                lines = quoted_text.split('\n')
                if len(lines) > 10:
                    lines[5:-5] = ['......']
                lines = ['```', ] + lines + ['```', ]
                quoted_text = '\n'.join(["> " + e for e in lines])
                return gr.update(value=quoted_text, visible=True)
            else:
                gr.Info('未找到匹配的引用文本')

        # NOTE(review): reconstructed indentation — fall-through (no match or
        # quoting disabled) clears any previous quote; confirm against original.
        writer['quote_span'] = None
        writer['quoted_text'] = ''
        return gr.update(visible=False)

    on_copy(on_copy_handle, [writer_state, setting_state, quote_checkbox], [quote_md])


    suggestion_textbox = gr.Textbox(max_length=1000, placeholder='在这里输入你的意见,或者从右上单选框选择', lines=2, scale=1, label=None, show_label=False, container=False)

    with gr.Row():
        accept_button = gr.Button("接受", scale=1, min_width=1, variant='secondary', interactive=False)
        pause_button = gr.Button("暂停", scale=1, min_width=1, variant='secondary', visible=RENDER_STOP_BTN)
        stop_button = gr.Button("取消", scale=1, min_width=1, variant='secondary')
        flash_button = gr.Button("刷新", scale=1, min_width=1, variant='secondary')

    def flash_interface(writer):
        """Re-derive every control's state from the writer dict.

        The local names deliberately shadow the components; each holds the
        gr.update(...) that is returned for that component.
        """
        current_w_name = writer['current_w']

        can_accept_flag = has_accept(writer) and not is_running(writer)
        can_write_flag = not writer_x_is_empty(writer, current_w_name) and not can_accept_flag

        # NOTE(review): all three cases are currently identical — presumably a
        # placeholder for stage-specific button labels later.
        match current_w_name:
            case 'outline_w':
                rewrite_all_button = gr.update(value='开始创作', variant='primary' if can_write_flag else 'secondary', interactive=can_write_flag)
            case 'chapters_w':
                rewrite_all_button = gr.update(value='开始创作', variant='primary' if can_write_flag else 'secondary', interactive=can_write_flag)
            case 'draft_w':
                rewrite_all_button = gr.update(value='开始创作', variant='primary' if can_write_flag else 'secondary', interactive=can_write_flag)

        accept_button = gr.update(variant='primary' if can_accept_flag else 'secondary', interactive=can_accept_flag)

        # Stage buttons: a stage unlocks once the previous stage has output.
        outline_btn = gr.update(
            variant='primary' if current_w_name == 'outline_w' else 'secondary'
        )
        chapters_btn = gr.update(
            interactive=not writer_y_is_empty(writer, 'outline_w'),
            variant='primary' if current_w_name == 'chapters_w' else 'secondary'
        )
        draft_btn = gr.update(
            interactive=not writer_y_is_empty(writer, 'chapters_w'),
            variant='primary' if current_w_name == 'draft_w' else 'secondary'
        )

        pause_button = gr.update(
            value="继续" if writer['pause_flag'] else "暂停",
            variant='secondary',
        )

        suggestion_choices = writer['suggestions'][current_w_name]
        # suggestion_choices = ['自动', ] + writer['suggestions'][current_w_name]  # TODO: "auto" incomplete as of V1.10, hidden
        if writer_y_is_empty(writer, current_w_name):
            # Nothing written yet for this stage: preselect the first suggestion.
            suggestion_dropdown = gr.update(choices=suggestion_choices, value=suggestion_choices[0])
        else:
            suggestion_dropdown = gr.update(choices=suggestion_choices,)

        chunk_length_choices = writer['chunk_length'][current_w_name]
        if cur_chunk_length := writer[current_w_name].get('y_chunk_length', None):
            chunk_length_dropdown = gr.update(choices=chunk_length_choices, value=cur_chunk_length)
        else:
            chunk_length_dropdown = gr.update(choices=chunk_length_choices, value=chunk_length_choices[0])

        return (
            create_text_md(writer),
            create_progress_md(writer),
            rewrite_all_button,
            accept_button,
            outline_btn,
            chapters_btn,
            draft_btn,
            pause_button,
            suggestion_dropdown,
            chunk_length_dropdown
        )

    # Shared event spec: refresh the whole interface from writer_state.
    flash_event = dict(
        fn=flash_interface,
        inputs=[writer_state],
        outputs=[
            text_md,
            progress_md,
            rewrite_all_button,
            accept_button,
            outline_btn,
            chapters_btn,
            draft_btn,
            pause_button,
            suggestion_dropdown,
            chunk_length_dropdown
        ]
    )

    flash_button.click(**flash_event)
    if RENDER_SAVE_LOAD_BTN:
        save_button.click(save_states, inputs=[save_file_name, writer_state, setting_state], outputs=[])
        load_button.click(load_states, inputs=[save_file_name], outputs=[writer_state, setting_state]).success(**flash_event)
    stop_button.click(try_cancel, inputs=[writer_state]).success(**flash_event).success(lambda: gr.update(), None, writer_state)
    # TODO: the stop button's writer_state update does not take effect

    @cancellable
    def _on_write_all(writer, setting, auto_write=False, suggestion=None):
        """Drive one full generation pass for the current stage.

        Streams (text_md, writer_state) updates; handles pause/preview and
        final acceptance bookkeeping. Wrapped by @cancellable for the cancel
        flag handling.
        """
        current_w_name = writer['current_w']

        if writer_x_is_empty(writer, current_w_name):
            gr.Info('请先输入需要创作的内容!')
            return

        writer['prompt_outputs'].clear()

        # Re-validate any pending quote before generating against it.
        if writer['quote_span']:
            quote_span, quoted_text = match_quote_text(writer, setting, writer['quoted_text'])
            if quote_span != writer['quote_span'] or quoted_text != writer['quoted_text']:
                raise gr.Error('引用文本不存在!')

        generator = call_write(process_writer_to_backend(writer), setting, auto_write, suggestion)

        new_writer = None
        while True:
            try:
                kp_msg = next(generator)
                if isinstance(kp_msg, KeyPointMsg):
                    # TODO: this flow is convoluted due to KeyPointMsg's design;
                    # consider simplifying later.
                    if kp_msg.is_prompt() and kp_msg.is_finished() and writer['pause_on_prompt_finished_flag']:
                        gr.Info('LLM响应完成,可以预览Prompt')
                        writer['pause_flag'] = True
                        if new_writer is None: continue
                    elif kp_msg.is_title():  # TODO: title nodes have no finish logic yet
                        # (checkpoint auto-save at title nodes was tried here
                        # but is buggy; may be offered in a later version)
                        continue
                    else:
                        continue
                else:
                    # Non-KeyPointMsg yields are intermediate writer snapshots.
                    new_writer = kp_msg

                if writer['pause_flag']:
                    # Deep-copy so the prompt preview renders a stable snapshot
                    # instead of chasing the still-mutating backend writer.
                    writer['prompt_outputs'] = copy.deepcopy(new_writer['prompt_outputs'])
                    yield create_text_md(new_writer), writer

                    # Busy-wait until "continue" clears the pause or a cancel arrives.
                    while writer['pause_flag'] and not writer['cancel_flag']:
                        time.sleep(0.1)
                else:
                    yield create_text_md(new_writer), gr.update()
            except StopIteration as e:
                # Generator finished: e.value is the final backend writer.
                process_writer_from_backend(writer, e.value)
                yield create_text_md(writer), writer
                if has_accept(writer):
                    gr.Info('创作完成!点击接受按钮接受修改。')
                else:
                    gr.Info('本次创作没有任何更改。')  # usually because the review pass decided nothing needed changing
                return

    def on_auto_write_all(writer, setting, auto_write):
        # Only acts in (currently hidden) auto mode; manual writes go through
        # on_write_all below.
        if auto_write:
            yield from _on_write_all(writer, setting, True)
        else:
            pass

    writer_all_events = dict(
        fn=on_auto_write_all,
        queue=True,
        inputs=[writer_state, setting_state, auto_checkbox],
        outputs=[text_md, writer_state],
        concurrency_limit=10
    )

    def on_init_outline(idea, writer):
        # Seed a fresh outline writer from the idea text, keeping unrelated keys.
        if not idea.strip():
            gr.Info("先输入小说简介或从示例中选择一个")
            return gr.update()
        new_writer = init_writer(idea)
        writer.update({
            k: v for k, v in new_writer.items() if k in ['current_w', 'outline_w', 'prompt_outputs']
        })
        return writer

    outline_btn.click(on_init_outline, inputs=[idea_textbox, writer_state], outputs=[writer_state]).success(**writer_all_events).then(**flash_event)
    chapters_btn.click(lambda writer: init_chapters_w(writer), inputs=[writer_state], outputs=[writer_state]).success(**writer_all_events).then(**flash_event)
    draft_btn.click(lambda writer: init_draft_w(writer), inputs=[writer_state], outputs=[writer_state]).success(**writer_all_events).then(**flash_event)

    def on_select_suggestion(writer, setting, choice):
        # Load the chosen suggestion prompt file into the suggestion textbox.
        if choice == '自动':
            return gr.update(value=choice, visible=False)

        current_w_name = writer['current_w']
        dirname = writer['suggestions_dirname'][current_w_name]
        suggestion = clean_txt_content(load_prompt(dirname, choice))
        # Strip the chat-role prefix if the prompt file carries one.
        if suggestion.startswith("user:\n"):
            suggestion = suggestion[len("user:\n"):]

        return gr.update(value=suggestion, visible=True)

    suggestion_dropdown.change(on_select_suggestion, inputs=[writer_state, setting_state, suggestion_dropdown], outputs=[suggestion_textbox])

    def on_write_all(writer, setting, suggestion):
        # Manual "start writing": requires a non-empty suggestion.
        if not suggestion.strip():
            gr.Info('需要输入创作意见!')
            return
        yield from _on_write_all(writer, setting, False, suggestion)

    rewrite_all_button.click(
        on_write_all,
        queue=True,
        inputs=[writer_state, setting_state, suggestion_textbox],
        outputs=[text_md, writer_state],
        concurrency_limit=10
    ).then(**flash_event)


    @cancellable
    def on_accept_write(writer, setting):
        # Commit pending apply_chunks for the current stage via the backend.
        current_w_name = writer['current_w']
        current_w = writer[current_w_name]

        if not current_w['apply_chunks']:
            raise gr.Error('请先进行创作!')

        new_writer = call_accept(process_writer_to_backend(writer), setting)
        process_writer_from_backend(writer, new_writer)
        yield create_text_md(writer), writer

    accept_button.click(fn=on_accept_write, inputs=[writer_state, setting_state], outputs=[text_md, writer_state]).then(**flash_event)

    def toggle_pause(writer):
        # Flip the pause flag polled by _on_write_all's busy-wait loop.
        if not is_running(writer):
            gr.Info('当前没有正在进行的操作')
            return gr.update()

        writer['pause_flag'] = not writer['pause_flag']
        return gr.update(value="暂停" if not writer['pause_flag'] else "继续")

    pause_button.click(
        toggle_pause,
        inputs=[writer_state],
        outputs=[pause_button]
    )

    @gr.render(inputs=setting_state)
    def _render_setting(setting):
        # Settings panel is re-rendered whenever setting_state is reassigned.
        return render_setting(setting, setting_state)


demo.queue()
demo.launch(server_name="0.0.0.0", server_port=7860)
#demo.launch()
434
+
435
+
core/frontend_copy.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr

# Injected into the Gradio page <head>. It listens for the browser "copy"
# event and forwards the selected text into the hidden #copy_textbox, whose
# input event then reaches the server-side handler registered by on_copy().
# (The comments inside this string are JavaScript source shipped to the
# browser, so they are left untranslated.)
enable_copy_js = """
<script>
document.addEventListener('copy', function(e) {
    // 获取选中的文本
    var selectedText = window.getSelection().toString();
    if(selectedText) {
        // 直接触发 gradio 组件的更新
        const textbox = document.getElementById('copy_textbox');
        if(textbox) {
            textbox.querySelector('textarea').value = selectedText;
            // 触发 change 事件以更新 Gradio 状态
            textbox.querySelector('textarea').dispatchEvent(new Event('input', { bubbles: true }));
        }
    }
});
</script>
"""

def on_copy(fn, inputs, outputs):
    """Register *fn* to be called with the browser-copied text.

    Creates the hidden textbox that enable_copy_js writes into and wires its
    change event to ``fn(copied_text, *inputs) -> outputs``. Must be called
    inside a gr.Blocks context built with ``head=enable_copy_js``.
    """
    copy_textbox = gr.Textbox(elem_id="copy_textbox", visible=False)
    return copy_textbox.change(fn, [copy_textbox] + inputs, outputs)


# Example usage (kept for reference):
# with gr.Blocks(head=enable_copy_js) as demo:
#     gr.Markdown("Hello\nTest Copy")
#     copy_textbox = gr.Textbox(elem_id="copy_textbox", visible=False)

#     def copy_handle(text):
#         gr.Info(text)

#     copy_textbox.change(copy_handle, copy_textbox)

#     demo.launch()
core/frontend_setting.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from enum import Enum, auto
3
+
4
+ from llm_api import ModelConfig, wenxin_model_config, doubao_model_config, gpt_model_config, zhipuai_model_config, test_stream_chat
5
+ from config import API_SETTINGS, RENDER_SETTING_API_TEST_BTN, ENABLE_SETTING_SELECT_SUB_MODEL
6
+
7
+
8
class Provider:
    """Display names for the supported LLM providers.

    The values double as dropdown labels and as dispatch keys in
    render_setting's match/if chains, so they must stay exactly as-is.
    """
    GPT = "GPT(OpenAI)"
    WENXIN = "文心(百度)"
    DOUBAO = "豆包(字节跳动)"
    ZHIPUAI = "GLM(智谱)"
    OTHERS = '其他'
14
+
15
def deep_update(d, u):
    """Merge *u* into *d* in place.

    Keys whose values are dicts on BOTH sides are merged recursively;
    everything else in *u* overwrites the corresponding entry in *d*.
    Returns None (mutates *d*).
    """
    for key, value in u.items():
        current = d.get(key)
        if isinstance(value, dict) and isinstance(current, dict):
            deep_update(current, value)
        else:
            d[key] = value
22
+
23
def new_setting():
    """Build a fresh settings dict seeded from config.API_SETTINGS.

    Returns a dict holding the active main/sub ModelConfig, UI bookkeeping
    fields, and one per-provider config sub-dict; any remaining keys from
    API_SETTINGS are deep-merged on top as overrides.
    """
    # Work on a shallow copy: the original implementation popped 'model' and
    # 'sub_model' straight out of the module-level API_SETTINGS, so a second
    # call raised KeyError and every other consumer saw a mutated config.
    api_settings = dict(API_SETTINGS)
    model_config = api_settings.pop('model')
    sub_model_config = api_settings.pop('sub_model')

    # (renamed from 'new_setting' to avoid shadowing this function's name)
    setting = dict(
        model=ModelConfig(**model_config),
        sub_model=ModelConfig(**sub_model_config),
        render_count=0,
        provider_name=Provider.GPT,
        wenxin={
            'ak': '',
            'sk': '',
            'default_model': 'ERNIE-Novel-8K',
            'default_sub_model': 'ERNIE-3.5-8K',
            'available_models': list(wenxin_model_config.keys())
        },
        doubao={
            'api_key': '',
            'main_endpoint_id': '',
            'sub_endpoint_id': '',
            'default_model': 'doubao-pro-32k',
            'default_sub_model': 'doubao-lite-32k',
            'available_models': list(doubao_model_config.keys())
        },
        gpt={
            'api_key': '',
            'base_url': '',
            'proxies': '',
            'default_model': 'gpt-4o',
            'default_sub_model': 'gpt-4o-mini',
            'available_models': list(gpt_model_config.keys())
        },
        zhipuai={
            'api_key': '',
            'default_model': 'glm-4-plus',
            'default_sub_model': 'glm-4-flashx',
            'available_models': list(zhipuai_model_config.keys())
        },
        others={
            'api_key': '',
            'base_url': '',
            'default_model': '',
            'default_sub_model': '',
            'available_models': []
        }
    )

    # Remaining API_SETTINGS entries (minus model/sub_model) override defaults.
    deep_update(setting, api_settings)

    return setting
73
+
74
+ # @gr.render(inputs=setting_state)
75
+ def render_setting(setting, setting_state):
76
+ with gr.Accordion("API 设置"):
77
+ with gr.Row():
78
+ provider_name = gr.Dropdown(
79
+ choices=[Provider.GPT, Provider.WENXIN, Provider.DOUBAO, Provider.ZHIPUAI, Provider.OTHERS],
80
+ value=setting['provider_name'],
81
+ label="模型提供商",
82
+ scale=1
83
+ )
84
+
85
+ def on_select_provider(provider_name):
86
+ setting['provider_name'] = provider_name
87
+ return setting
88
+
89
+ provider_name.select(fn=on_select_provider, inputs=provider_name, outputs=[setting_state])
90
+
91
+ match setting['provider_name']:
92
+ case Provider.WENXIN:
93
+ provider_config = setting['wenxin']
94
+ case Provider.DOUBAO:
95
+ provider_config = setting['doubao']
96
+ case Provider.GPT:
97
+ provider_config = setting['gpt']
98
+ case Provider.ZHIPUAI:
99
+ provider_config = setting['zhipuai']
100
+ case Provider.OTHERS:
101
+ provider_config = setting['others']
102
+
103
+ main_model = gr.Dropdown(
104
+ choices=provider_config['available_models'],
105
+ value=provider_config['default_model'],
106
+ label="主模型",
107
+ scale=1,
108
+ allow_custom_value=setting['provider_name'] == Provider.OTHERS
109
+ )
110
+
111
+ sub_model = gr.Dropdown(
112
+ choices=provider_config['available_models'],
113
+ value=provider_config['default_sub_model'],
114
+ label="辅助模型",
115
+ scale=1,
116
+ allow_custom_value=setting['provider_name'] == Provider.OTHERS,
117
+ interactive=ENABLE_SETTING_SELECT_SUB_MODEL
118
+ )
119
+
120
+ with gr.Row():
121
+ if setting['provider_name'] == Provider.WENXIN:
122
+ baidu_access_key = gr.Textbox(
123
+ value=provider_config['ak'],
124
+ label='Baidu Access Key',
125
+ lines=1,
126
+ placeholder='Enter your Baidu access key here',
127
+ interactive=True,
128
+ scale=10,
129
+ type='password'
130
+ )
131
+ baidu_secret_key = gr.Textbox(
132
+ value=provider_config['sk'],
133
+ label='Baidu Secret Key',
134
+ lines=1,
135
+ placeholder='Enter your Baidu secret key here',
136
+ interactive=True,
137
+ scale=10,
138
+ type='password'
139
+ )
140
+
141
+ elif setting['provider_name'] == Provider.DOUBAO:
142
+ doubao_api_key = gr.Textbox(
143
+ value=provider_config['api_key'],
144
+ label='Doubao API Key',
145
+ lines=1,
146
+ placeholder='Enter your Doubao API key here',
147
+ interactive=True,
148
+ scale=10,
149
+ type='password'
150
+ )
151
+ main_endpoint_id = gr.Textbox(
152
+ value=provider_config['main_endpoint_id'],
153
+ label='Main Endpoint ID',
154
+ lines=1,
155
+ placeholder='Enter your main endpoint ID here',
156
+ interactive=True,
157
+ scale=10,
158
+ type='password'
159
+ )
160
+ sub_endpoint_id = gr.Textbox(
161
+ value=provider_config['sub_endpoint_id'],
162
+ label='Sub Endpoint ID',
163
+ lines=1,
164
+ placeholder='Enter your sub endpoint ID here',
165
+ interactive=True,
166
+ scale=10,
167
+ type='password'
168
+ )
169
+
170
+ elif setting['provider_name'] in [Provider.GPT, Provider.OTHERS]:
171
+ gpt_api_key = gr.Textbox(
172
+ value=provider_config['api_key'],
173
+ label='OpenAI API Key',
174
+ lines=1,
175
+ placeholder='Enter your OpenAI API key here',
176
+ interactive=True,
177
+ scale=10,
178
+ type='password'
179
+ )
180
+ base_url = gr.Textbox(
181
+ value=provider_config['base_url'],
182
+ label='API Base URL',
183
+ lines=1,
184
+ placeholder='Enter API base URL here',
185
+ interactive=True,
186
+ scale=10,
187
+ type='password'
188
+ )
189
+
190
+ elif setting['provider_name'] == Provider.ZHIPUAI:
191
+ zhipuai_api_key = gr.Textbox(
192
+ value=provider_config['api_key'],
193
+ label='ZhipuAI API Key',
194
+ lines=1,
195
+ placeholder='Enter your ZhipuAI API key here',
196
+ interactive=True,
197
+ scale=10,
198
+ type='password'
199
+ )
200
+
201
+ with gr.Row():
202
+ if setting['provider_name'] == Provider.WENXIN:
203
+ def on_submit(main_model, sub_model, baidu_access_key, baidu_secret_key):
204
+ provider_config['ak'] = baidu_access_key
205
+ provider_config['sk'] = baidu_secret_key
206
+
207
+ setting['model'] = ModelConfig(
208
+ model=main_model,
209
+ ak=baidu_access_key,
210
+ sk=baidu_secret_key,
211
+ max_tokens=4096
212
+ )
213
+ setting['sub_model'] = ModelConfig(
214
+ model=sub_model,
215
+ ak=baidu_access_key,
216
+ sk=baidu_secret_key,
217
+ max_tokens=4096
218
+ )
219
+
220
+ submit_event = dict(
221
+ fn=on_submit,
222
+ inputs=[main_model, sub_model, baidu_access_key, baidu_secret_key],
223
+ )
224
+
225
+ on_submit(main_model.value, sub_model.value, baidu_access_key.value, baidu_secret_key.value)
226
+
227
+ main_model.change(**submit_event)
228
+ sub_model.change(**submit_event)
229
+ baidu_access_key.change(**submit_event)
230
+ baidu_secret_key.change(**submit_event)
231
+
232
+ elif setting['provider_name'] == Provider.DOUBAO:
233
+ def on_submit(main_model, sub_model, doubao_api_key, main_endpoint_id, sub_endpoint_id):
234
+ provider_config['api_key'] = doubao_api_key
235
+ provider_config['main_endpoint_id'] = main_endpoint_id
236
+ provider_config['sub_endpoint_id'] = sub_endpoint_id
237
+
238
+ setting['model'] = ModelConfig(
239
+ model=main_model,
240
+ api_key=doubao_api_key,
241
+ endpoint_id=main_endpoint_id,
242
+ max_tokens=4096
243
+ )
244
+ setting['sub_model'] = ModelConfig(
245
+ model=sub_model,
246
+ api_key=doubao_api_key,
247
+ endpoint_id=sub_endpoint_id,
248
+ max_tokens=4096
249
+ )
250
+
251
+ submit_event = dict(
252
+ fn=on_submit,
253
+ inputs=[main_model, sub_model, doubao_api_key, main_endpoint_id, sub_endpoint_id],
254
+ )
255
+
256
+ on_submit(main_model.value, sub_model.value, doubao_api_key.value, main_endpoint_id.value, sub_endpoint_id.value)
257
+
258
+ main_model.change(**submit_event)
259
+ sub_model.change(**submit_event)
260
+ doubao_api_key.change(**submit_event)
261
+ main_endpoint_id.change(**submit_event)
262
+ sub_endpoint_id.change(**submit_event)
263
+
264
+ elif setting['provider_name'] in [Provider.GPT, Provider.OTHERS]:
265
+ def on_submit(main_model, sub_model, gpt_api_key, base_url):
266
+ provider_config['api_key'] = gpt_api_key
267
+ provider_config['base_url'] = base_url.strip()
268
+
269
+ setting['model'] = ModelConfig(
270
+ model=main_model,
271
+ api_key=provider_config['api_key'],
272
+ base_url=provider_config['base_url'],
273
+ max_tokens=4096,
274
+ proxies=provider_config.get('proxies', None),
275
+ )
276
+ setting['sub_model'] = ModelConfig(
277
+ model=sub_model,
278
+ api_key=provider_config['api_key'],
279
+ base_url=provider_config['base_url'],
280
+ max_tokens=4096,
281
+ proxies=provider_config.get('proxies', None),
282
+ )
283
+
284
+ submit_event = dict(
285
+ fn=on_submit,
286
+ inputs=[main_model, sub_model, gpt_api_key, base_url],
287
+ )
288
+
289
+ on_submit(main_model.value, sub_model.value, gpt_api_key.value, base_url.value)
290
+
291
+ main_model.change(**submit_event)
292
+ sub_model.change(**submit_event)
293
+ gpt_api_key.change(**submit_event)
294
+ base_url.change(**submit_event)
295
+
296
+ elif setting['provider_name'] == Provider.ZHIPUAI:
297
+ def on_submit(main_model, sub_model, zhipuai_api_key):
298
+ provider_config['api_key'] = zhipuai_api_key
299
+
300
+ setting['model'] = ModelConfig(
301
+ model=main_model,
302
+ api_key=zhipuai_api_key,
303
+ max_tokens=4096
304
+ )
305
+ setting['sub_model'] = ModelConfig(
306
+ model=sub_model,
307
+ api_key=zhipuai_api_key,
308
+ max_tokens=4096
309
+ )
310
+
311
+ submit_event = dict(
312
+ fn=on_submit,
313
+ inputs=[main_model, sub_model, zhipuai_api_key],
314
+ )
315
+
316
+ on_submit(main_model.value, sub_model.value, zhipuai_api_key.value)
317
+
318
+ main_model.change(**submit_event)
319
+ sub_model.change(**submit_event)
320
+ zhipuai_api_key.change(**submit_event)
321
+
322
+ if RENDER_SETTING_API_TEST_BTN:
323
+ test_btn = gr.Button("测试")
324
+ test_report = gr.Textbox(show_label=False, container=False, value='', interactive=False, scale=10)
325
+
326
+ def on_test_llm_api():
327
+ if not setting['model']['model'].strip():
328
+ return gr.Info('主模型名不能为空')
329
+
330
+ if not setting['sub_model']['model'].strip():
331
+ return gr.Info('辅助模型名不能为空')
332
+
333
+ try:
334
+ response1 = yield from test_stream_chat(setting['model'])
335
+ response2 = yield from test_stream_chat(setting['sub_model'])
336
+ report_text = f"User:1+1=?\n主模型 :{response1.response}({response1.cost_info})\n辅助模型:{response2.response}({response2.cost_info})\n测试通过!"
337
+ yield report_text
338
+ except Exception as e:
339
+ yield f"测试失败:{str(e)}"
340
+
341
+ if RENDER_SETTING_API_TEST_BTN:
342
+ test_btn.click(
343
+ on_test_llm_api,
344
+ outputs=[test_report]
345
+ )
core/frontend_utils.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import functools
3
+ import pickle
4
+ import os
5
+ import time
6
+ import gradio as gr
7
+
8
+ from core.writer import Chunk
9
+
10
# Banner HTML rendered at the top of the gradio app.
title = """
<div style="text-align: center; padding: 10px 20px;">
    <h1 style="margin: 0 0 5px 0;">🖋️ Long-Novel-GPT 1.10</h1>
    <p style="margin: 0;"><em>AI一键生成长篇小说</em></p>
</div>
"""

# Quick-start usage notes shown in the UI (user-facing runtime text, kept in Chinese).
info = \
"""1. 当前Demo支持GPT、Claude、文心、豆包、GLM等模型,并且已经配置了API-Key,默认模型为GPT4o,最大线程数为5。
2. 可以选中**示例**中的任意一个创意,然后点击**创作大纲**来初始化大纲。
3. 初始化后,点击**开始创作**按钮,可以不断创作大纲,直到满意为止。
4. 创建完大纲后,点击**创作剧情**按钮,之后重复以上流程。
5. 选中**一键生成**后,再次点击左侧按钮可以一键生成。
6. 如果遇到任何无法解决的问题,请点击**刷新**按钮。
7. 如果问题还是无法解决,请刷新浏览器页面,这会导致丢失所有数据,请手动备份重要文本。
"""
26
+
27
+
28
def init_writer(idea, check_empty=True):
    """Build the initial writer-state dict from a novel idea string.

    The state holds three sub-writers (outline/chapters/draft), run-control
    flags, UI suggestions and per-stage chunk-length presets. Raises when
    check_empty is True and the idea is empty.
    """
    def _empty_w(x_seed):
        # Each sub-writer tracks its own cost plus its (x, y) text pairs.
        return dict(
            current_cost=0,
            total_cost=0,
            currency_symbol='¥',
            xy_pairs=[(x_seed, '')],
            apply_chunks={},
        )

    writer = dict(
        current_w='outline_w',
        outline_w=_empty_w(idea),
        chapters_w=_empty_w(''),
        draft_w=_empty_w(''),
        running_flag=False,
        cancel_flag=False,  # set to cancel an in-flight operation
        pause_flag=False,   # set to pause an in-flight operation
        progress={},
        prompt_outputs=[],  # when this key is present, prompt outputs are shown in the gradio UI
        suggestions=dict(
            outline_w=['新建大纲', '扩写大纲', '润色大纲'],
            chapters_w=['新建剧情', '扩写剧情', '润色剧情'],
            draft_w=['新建正文', '扩写正文', '润色正文'],
        ),
        suggestions_dirname=dict(
            outline_w='prompts/创作大纲',
            chapters_w='prompts/创作剧情',
            draft_w='prompts/创作正文',
        ),
        pause_on_prompt_finished_flag=False,
        quote_span=None,
        chunk_length=dict(
            outline_w=[4_000, ],
            chapters_w=[500, 200, 1000, 2000],
            draft_w=[1000, 500, 2000, 3000],
        ),
    )

    if check_empty and writer_x_is_empty(writer, writer['current_w']):
        raise Exception('请先输入小说简介!')
    return writer
90
+
91
def init_chapters_w(writer, check_empty=True):
    """Seed the chapters stage: the outline's full y text becomes the
    chapters' x text, and the chapters stage becomes active."""
    outline_text = ''.join(pair[1] for pair in writer['outline_w']['xy_pairs'])
    writer['chapters_w']['xy_pairs'] = [(outline_text, '')]

    writer['current_w'] = 'chapters_w'

    if check_empty and writer_x_is_empty(writer, 'chapters_w'):
        raise Exception('大纲不能为空')
    return writer
104
+
105
def init_draft_w(writer, check_empty=True):
    """Seed the draft stage: the chapters' full y text becomes the draft's
    x text, and the draft stage becomes active."""
    chapters_text = ''.join(pair[1] for pair in writer['chapters_w']['xy_pairs'])
    writer['draft_w']['xy_pairs'] = [(chapters_text, '')]

    writer['current_w'] = 'draft_w'

    if check_empty and writer_x_is_empty(writer, 'draft_w'):
        raise Exception('剧情不能为空')
    return writer
118
+
119
# Only the backend-relevant slice of `writer` crosses the frontend/backend
# boundary; the returned dict can later be merged straight back into the
# writer state.
def process_writer_to_backend(writer):
    """Deep-copy just the keys the backend consumes."""
    keys = ('current_w', 'outline_w', 'chapters_w', 'draft_w', 'quote_span')
    return copy.deepcopy({k: writer[k] for k in keys})
125
+
126
# The writer-state object must keep the same identity for its whole
# lifetime (flag checks go through the original object), so backend results
# are copied into it rather than replacing it.
def process_writer_from_backend(writer, new_writer):
    """Copy the three sub-writer states back into the live writer dict."""
    for key in ('outline_w', 'chapters_w', 'draft_w'):
        writer[key] = copy.deepcopy(new_writer[key])
    return writer
131
+
132
def is_running(writer):
    """True while an operation is in flight and has not been cancelled."""
    if writer['cancel_flag']:
        return False
    return writer['running_flag']
135
+
136
def has_accept(writer):
    """True when the active sub-writer holds pending (not yet accepted) text."""
    active = writer[writer['current_w']]
    return bool(active['apply_chunks'])
140
+
141
def cancellable(func):
    """Decorator for writer operations implemented as generators.

    Refuses to start while another operation runs or while text awaits
    acceptance, sets the running/cancel/pause flags around the wrapped
    generator, re-yields everything it produces, and checks the cancel
    flag between yields.
    """
    @functools.wraps(func)
    def wrapper(writer, *args, **kwargs):
        if is_running(writer):
            gr.Warning('另一个操作正在进行中,请等待其完成或取消!')
            return

        # Pending text must be accepted or discarded first — except by the
        # accept handler itself (functools.wraps makes wrapper.__name__ equal
        # to the wrapped function's name).
        if has_accept(writer) and wrapper.__name__ != "on_accept_write":
            gr.Warning('有正在等待接受的文本,点击接受或取消!')
            return

        writer['running_flag'] = True
        writer['cancel_flag'] = False
        writer['pause_flag'] = False

        generator = func(writer, *args, **kwargs)
        result = None
        try:
            while True:
                # Cancellation is honoured between yields of the wrapped generator.
                if writer['cancel_flag']:
                    gr.Info('操作已取消!')
                    return

                # Pause logic is implemented inside func itself, so each
                # operation can run code right before/after pausing.
                try:
                    result = next(generator)
                    # If the yielded tuple carries the writer-state dict, the
                    # object identity must not have changed: flag checks go
                    # through the original object.
                    if isinstance(result, tuple) and (writer_dict := next((item for item in result if isinstance(item, dict) and 'running_flag' in item), None)):
                        assert writer is writer_dict, 'writer对象地址发生了改变'
                        writer = writer_dict
                    yield result
                except StopIteration as e:
                    return e.value
        except Exception as e:
            raise gr.Error(f'操作过程中发生错误:{e}')
        finally:
            # Always clear run state, whether finished, cancelled, or failed.
            writer['running_flag'] = False
            writer['pause_flag'] = False

    return wrapper
180
+
181
def try_cancel(writer):
    """Cancel the running operation, or discard pending un-accepted text.

    Pending text is discarded first when nothing is running; otherwise the
    cancel flag is raised and we poll (up to ~3s) for the cancellable()
    wrapper to notice it and clear running_flag.
    """
    if not (is_running(writer) or has_accept(writer)):
        gr.Info('当前没有正在进行的操作或待接受的文本')
        return

    writer['prompt_outputs'] = []
    current_w = writer[writer['current_w']]
    if not is_running(writer) and has_accept(writer):  # prefer discarding pending text
        current_w['apply_chunks'].clear()
        gr.Info('已取消待接受的文本')
        return

    writer['cancel_flag'] = True

    # Poll until the running operation acknowledges the cancel, 3s timeout.
    start_time = time.time()
    while writer['running_flag'] and time.time() - start_time < 3:
        time.sleep(0.1)

    if writer['running_flag']:
        gr.Warning('取消操作超时,可能需要刷新页面')

    writer['cancel_flag'] = False
203
+
204
def writer_y_is_empty(writer, w_name):
    """True when every y string of the named sub-writer is empty."""
    return not any(pair[1] for pair in writer[w_name]['xy_pairs'])
207
+
208
def writer_x_is_empty(writer, w_name):
    """True when every x string of the named sub-writer is empty."""
    return not any(pair[0] for pair in writer[w_name]['xy_pairs'])
211
+
212
+
213
# TODO: smarter column sizing — narrow columns for short cells, wide for long.
def create_comparison_table(pairs, column_names=('Original Text', 'Enhanced Text', 'Enhanced Text 2')):
    """Render (x, y[, y2]) text pairs as a markdown comparison table.

    A third column is emitted when ANY pair has three elements; two-element
    pairs then get an empty third cell. Cell text is escaped so pipes and
    newlines do not break the table.

    Fix: the default argument is now an immutable tuple (the original used a
    shared mutable list default); escaping is factored into one helper.
    """
    def _cell(text):
        # Escape markdown table syntax; keep multi-line text inside one cell.
        return text.replace('|', '\\|').replace('\n', '<br>')

    has_third_column = any(len(pair) == 3 for pair in pairs)

    # Table header (divider widths kept identical to the original output).
    if has_third_column:
        table = f"| {column_names[0]} | {column_names[1]} | {column_names[2]} |\n|---------------|-----------------|----------------|\n"
    else:
        table = f"| {column_names[0]} | {column_names[1]} |\n|---------------|---------------|\n"

    for pair in pairs:
        x, y1 = _cell(pair[0]), _cell(pair[1])
        if has_third_column:
            y2 = _cell(pair[2]) if len(pair) == 3 else ''
            table += f"| {x} | {y1} | {y2} |\n"
        else:
            table += f"| {x} | {y1} |\n"

    return table
237
+
238
def messages2chatbot(messages):
    """Adapt an LLM message list for the gradio Chatbot: a leading system
    message is re-labelled as a user message; everything else passes through."""
    if not messages or messages[0]['role'] != 'system':
        return messages
    first = {'role': 'user', 'content': messages[0]['content']}
    return [first] + messages[1:]
243
+
244
def create_progress_md(writer):
    """Render writer['progress'] (a flat op list with title/subtitle) as
    nested markdown headings; the current op gets an animated dot suffix."""
    progress_md = ""
    if 'progress' in writer and writer['progress']:
        progress = writer['progress']
        progress_md = ""

        # Deduplicate titles/subtitles while preserving first-seen order.
        titles = []
        subtitles = {}
        current_op_ij = (float('inf'), float('inf'))
        for opi, op in enumerate(progress['ops']):
            if op['title'] not in titles:
                titles.append(op['title'])
            if op['title'] not in subtitles:
                subtitles[op['title']] = []
            if op['subtitle'] not in subtitles[op['title']]:
                subtitles[op['title']].append(op['subtitle'])

            # 1-based (section, subsection) position of the op in progress.
            if opi == progress['cur_op_i']:
                current_op_ij = (len(titles), len(subtitles[op['title']]))

        for i, title in enumerate(titles, 1):
            # Chinese numerals for section headings (supports up to 10 sections).
            progress_md += f"## {['一', '二', '三', '四', '五', '六', '七', '八', '九', '十'][i-1]}、{title}\n"
            for j, subtitle in enumerate(subtitles[title], 1):
                if i < current_op_ij[0] or (i == current_op_ij[0] and j < current_op_ij[1]):
                    progress_md += f"### {j}、{subtitle} ✓\n"  # completed step
                elif i == current_op_ij[0] and j == current_op_ij[1]:
                    # current step: 0-3 trailing dots, cycling once per second
                    progress_md += f"### {j}、{subtitle} {'.' * (int(time.time()) % 4)}\n"
                else:
                    progress_md += f"### {j}、{subtitle}\n"

            progress_md += "\n"

        progress_md += "---\n"
        # TODO: consider showing only the current step

    return gr.Markdown(progress_md)
281
+
282
+
283
def create_text_md(writer):
    """Render the active sub-writer's (x, y) pairs as a markdown comparison
    table; pending apply_chunks are overlaid as a third revision column."""
    current_w_name = writer['current_w']
    current_w = writer[current_w_name]
    apply_chunks = current_w['apply_chunks']

    # Column headers depend on which creation stage is active.
    match current_w_name:
        case 'draft_w':
            column_names = ['剧情', '正文', '修正稿']
        case 'outline_w':
            column_names = ['小说简介', '大纲', '修正稿']
        case 'chapters_w':
            column_names = ['大纲', '剧情', '修正稿']
        case _:
            raise Exception('当前状态不正确')

    xy_pairs = current_w['xy_pairs']
    if apply_chunks:
        # Start from [x, y, ''] rows, then overlay each pending chunk onto
        # the row span it covers (each chunk collapses its span to one row).
        table = [[*e, ''] for e in xy_pairs]
        occupied_rows = [False] * len(table)
        for chunk, key, text in apply_chunks:
            if not isinstance(chunk, Chunk):
                chunk = Chunk(**chunk)  # rebuild after (de)serialization
            assert key == 'y_chunk'
            pair_span = chunk.text_source_slice
            if any(occupied_rows[pair_span]):
                raise Exception('apply_chunks中存在重叠的pair_span')
            occupied_rows[pair_span] = [True] * (pair_span.stop - pair_span.start)
            # One merged row followed by None placeholders keeps indices stable
            # while iterating the remaining chunks.
            table[pair_span] = [[chunk.x_chunk, chunk.y_chunk, text], ] + [None] * (pair_span.stop - pair_span.start - 1)
        table = [e for e in table if e is not None]
        if not any(e[1] for e in table):
            # No existing y text anywhere: show only x plus the pending text.
            column_names = column_names[:2]
            column_names[1] = column_names[1] + '(待接受)'
            table = [[e[0], e[2]] for e in table]
        md = create_comparison_table(table, column_names=column_names)
    else:
        if writer_x_is_empty(writer, current_w_name):
            # Nothing to show yet: display onboarding tips instead.
            tip_x = '从下方示例中选择一个创意用于创作小说。'
            tip_y = '选择创意后,点击创作大纲。更详细的操作请参考使用指南。'
            if not xy_pairs[0][0].strip():
                xy_pairs = [[tip_x, tip_y]]
            else:
                xy_pairs = [[xy_pairs[0][0], tip_y]]

        md = create_comparison_table(xy_pairs, column_names=column_names[:2])

    # Short tables get a compact view.
    if len(md) < 400:
        height = '200px'
    else:
        height = '600px'
    return gr.Markdown(md, height=height)
333
+
core/outline_writer.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from core.parser_utils import parse_chapters
2
+ from core.writer_utils import KeyPointMsg
3
+ from core.writer import Writer
4
+
5
+ from prompts.创作章节.prompt import main as prompt_outline
6
+ from prompts.提炼.prompt import main as prompt_summary
7
+
8
class OutlineWriter(Writer):
    """Writer for the book-level outline: x is the novel idea/summary side,
    y is the chapter-structured outline text."""

    def __init__(self, xy_pairs, global_context, model=None, sub_model=None, x_chunk_length=2_000, y_chunk_length=2_000, max_thread_num=5):
        super().__init__(xy_pairs, global_context, model, sub_model, x_chunk_length=x_chunk_length, y_chunk_length=y_chunk_length, max_thread_num=max_thread_num)

    def write(self, user_prompt, pair_span=None):
        """Generate/rewrite the outline for pair_span, then re-split the
        result into one pair per chapter. Generator: yields progress and
        applies the new chunks on completion."""
        target_chunk = self.get_chunk(pair_span=pair_span)

        if not self.global_context.get("summary", ''):
            raise Exception("需要提供小说简介。")

        if not target_chunk.y_chunk.strip():
            # An empty selection is only valid for the very first generation
            # (when the whole y side is still empty).
            if not self.y.strip():
                chunks = [target_chunk, ]
            else:
                raise Exception("选中进行创作的内容不能为空,考虑随便填写一些占位的字。")
        else:
            chunks = self.get_chunks(pair_span)

        new_chunks = yield from self.batch_yield(
            [self.write_text(e, prompt_outline, user_prompt) for e in chunks],
            chunks, prompt_name='创作文本')

        results = yield from self.batch_split_chapters(new_chunks)

        new_chunks2 = [e[0] for e in results]

        self.apply_chunks(chunks, new_chunks2)

    def split_chapters(self, chunk):
        """Split a chunk's y text into one (x='', y=chapter) pair per chapter."""
        if False: yield  # makes this function a generator (batch_yield expects generators)

        assert chunk.x_chunk == '', 'chunk.x_chunk不为空'
        chapter_titles, chapter_contents = parse_chapters(chunk.y_chunk)
        new_xy_pairs = self.construct_xy_pairs(chapter_titles, chapter_contents)

        return chunk.edit(text_pairs=new_xy_pairs), True, ''

    def construct_xy_pairs(self, chapter_titles, chapter_contents):
        # Each title is a (chapter_number, chapter_name) tuple from parse_chapters.
        return [('', f"{title[0]} {title[1]}\n{content}") for title, content in zip(chapter_titles, chapter_contents)]

    def batch_split_chapters(self, chunks):
        """Run split_chapters over all chunks through batch_yield."""
        results = yield from self.batch_yield(
            [self.split_chapters(e) for e in chunks], chunks, prompt_name='划分章节')
        return results

    def summary(self):
        """Condense the full outline into global_context['outline'].
        Outlines over 2000 chars are sampled down first."""
        target_chunk = self.get_chunk(pair_span=(0, len(self.xy_pairs)))
        if not target_chunk.y_chunk:
            raise Exception("没有章节需要总结。")
        if len(target_chunk.y_chunk) <= 5:
            raise Exception("需要总结的章节不能少于5个字。")

        if len(target_chunk.y_chunk) > 2000:
            y = self._truncate_chunk(target_chunk.y_chunk)
        else:
            y = target_chunk.y_chunk

        result = yield from prompt_summary(self.model, "提炼大纲", y=y)

        self.global_context['outline'] = result['text']

    def get_model(self):
        return self.model

    def get_sub_model(self):
        return self.sub_model

    def _truncate_chunk(self, text, chunk_size=100, keep_chunks=20):
        """Truncate chunk content by keeping evenly spaced sections."""
        if len(text) <= 2000:
            return text

        # Split into chunks of chunk_size
        chunks = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]

        # Select evenly spaced chunks (len(chunks) > 20 here, so step >= 1)
        step = len(chunks) // keep_chunks
        selected_chunks = chunks[::step][:keep_chunks]
        new_content = '...'.join(selected_chunks)
        return new_content
88
+
core/parser_utils.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+
4
def parse_chapters(content):
    """Split novel text into chapters by headings shaped like ``第X章 标题``.

    Returns (titles, contents) where titles is a list of
    (chapter_number, chapter_name) tuples and contents the matching bodies;
    both are stripped. Returns ([], []) when no heading is found.
    """
    # Heading: 第 + (Chinese numerals / digits / '.' / '-') + 章.
    heading = r'第[零一二三四五六七八九十百千万亿0123456789.-]+章'
    # Capture the marker, the rest of the heading line, and the body up to
    # the next marker (or end of text).
    pattern = rf'({heading})([^\n]*)\n*([\s\S]*?)(?={heading}|$)'

    titles = []
    contents = []
    for number, name, body in re.findall(pattern, content):
        titles.append((number, name.strip()))
        contents.append(body.strip())

    return titles, contents
16
+
17
+
18
if __name__ == "__main__":
    # Ad-hoc manual check: exercises dashed/dotted chapter numbers and blank
    # lines between heading and body. `results` is meant to be inspected in
    # a debugger; the bare print() gives a breakpoint-friendly last line.
    test = """
    第1-1章 出世
    主角张小凡出身贫寒,因天赋异禀被青云门收为弟子,开始修仙之路。

    第2.1章 初入青云

    张小凡在青云门中结识师兄弟,学习基础法术,逐渐适应修仙生活。

    第3章 灵气初现
    张小凡在一次意外中感受到天地灵气,修为有所提升。
    """

    results = parse_chapters(test)
    print()
core/plot_writer.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from core.writer_utils import KeyPointMsg
2
+ from core.writer import Writer
3
+
4
+ from prompts.创作剧情.prompt import main as prompt_plot
5
+ from prompts.提炼.prompt import main as prompt_summary
6
+
7
class PlotWriter(Writer):
    """Writer for chapter plots: x is a chapter outline, y is the detailed
    plot text generated from it."""

    def __init__(self, xy_pairs, global_context, model=None, sub_model=None, x_chunk_length=200, y_chunk_length=1000, max_thread_num=5):
        super().__init__(xy_pairs, global_context, model, sub_model, x_chunk_length=x_chunk_length, y_chunk_length=y_chunk_length, max_thread_num=max_thread_num)

    def write(self, user_prompt, pair_span=None):
        """Generate/rewrite plot text for pair_span, then re-map x to y.
        Generator: yields progress and applies the result on completion."""
        target_chunk = self.get_chunk(pair_span=pair_span)

        if not self.global_context.get("chapter", ''):
            raise Exception("需要提供章节内容。")

        if not target_chunk.y_chunk.strip():
            # An empty selection is only valid for the very first generation
            # (when the whole y side is still empty).
            if not self.y.strip():
                chunks = [target_chunk, ]
            else:
                raise Exception("选中进行创作的内容不能为空,考虑随便填写一些占位的字。")
        else:
            chunks = self.get_chunks(pair_span)

        new_chunks = yield from self.batch_yield(
            [self.write_text(e, prompt_plot, user_prompt) for e in chunks],
            chunks, prompt_name='创作文本')

        results = yield from self.batch_map_text(new_chunks)
        new_chunks2 = [e[0] for e in results]

        self.apply_chunks(chunks, new_chunks2)

    def summary(self):
        """Condense the full plot text into global_context['chapter']."""
        target_chunk = self.get_chunk(pair_span=(0, len(self.xy_pairs)))
        if not target_chunk.y_chunk:
            raise Exception("没有剧情需要总结。")
        if len(target_chunk.y_chunk) <= 5:
            raise Exception("需要总结的剧情不能少于5个字。")

        result = yield from prompt_summary(self.model, "提炼章节", y=target_chunk.y_chunk)

        self.global_context['chapter'] = result['text']

    def get_model(self):
        return self.model

    def get_sub_model(self):
        return self.sub_model
core/summary_novel.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from core.draft_writer import DraftWriter
3
+ from core.plot_writer import PlotWriter
4
+ from core.outline_writer import OutlineWriter
5
+ from core.writer_utils import KeyPointMsg
6
+
7
+
8
+
9
def summary_draft(model, sub_model, chapter_title, chapter_text):
    """Condense one chapter's draft text into its plot summary.

    Generator: yields progress dicts (progress_msg, chars_num, current_cost,
    currency_symbol, model) and returns the DraftWriter when done.

    Fix: the per-batch local previously named `model` shadowed the `model`
    parameter; renamed to `used_model`.
    """
    xy_pairs = [('', chapter_text)]

    dw = DraftWriter(xy_pairs, {}, model=model, sub_model=sub_model, x_chunk_length=500, y_chunk_length=1000)
    dw.max_thread_num = 1  # each chapter is processed with a single thread

    generator = dw.summary(pair_span=(0, len(xy_pairs)))

    kp_msg_title = ''
    for kp_msg in generator:
        if isinstance(kp_msg, KeyPointMsg):
            # To support key-point checkpointing an edit diff would have to be
            # computed and the writer yielded here.
            kp_msg_title = kp_msg.prompt_name
            continue
        else:
            chunk_list = kp_msg

        # Aggregate cost/size over the finished chunks of this batch.
        current_cost = 0
        currency_symbol = ''
        finished_chunk_num = 0
        chars_num = 0
        used_model = None
        for e in chunk_list:
            if e is None:
                continue
            finished_chunk_num += 1
            output, chunk = e  # chunk itself is not needed for progress reporting
            if output is None:
                continue  # map_text hit StopIteration on the first next()
            current_cost += output['response_msgs'].cost
            currency_symbol = output['response_msgs'].currency_symbol
            chars_num += len(output['response_msgs'].response)
            used_model = output['response_msgs'].model

        yield dict(
            progress_msg=f"[{chapter_title}] 提炼章节剧情 {kp_msg_title} 进度:{finished_chunk_num}/{len(chunk_list)} 已创作字符:{chars_num} 已花费:{current_cost:.4f}{currency_symbol}",
            chars_num=chars_num,
            current_cost=current_cost,
            currency_symbol=currency_symbol,
            model=used_model
        )

    return dw
50
+
51
+
52
def summary_plot(model, sub_model, chapter_title, chapter_plot):
    """Condense one chapter's plot text into a chapter outline.

    Generator: yields progress dicts and returns the PlotWriter; the result
    lands in the writer's global_context['chapter'].
    """
    xy_pairs = [('', chapter_plot)]

    pw = PlotWriter(xy_pairs, {}, model=model, sub_model=sub_model, x_chunk_length=500, y_chunk_length=1000)

    generator = pw.summary()

    for output in generator:
        # Each yielded output carries the LLM response and its cost metadata.
        current_cost = output['response_msgs'].cost
        currency_symbol = output['response_msgs'].currency_symbol
        chars_num = len(output['response_msgs'].response)
        yield dict(
            progress_msg=f"[{chapter_title}] 提炼章节大纲 已创作字符:{chars_num} 已花费:{current_cost:.4f}{currency_symbol}",
            chars_num=chars_num,
            current_cost=current_cost,
            currency_symbol=currency_symbol,
            model=output['response_msgs'].model
        )

    return pw
72
+
73
def summary_chapters(model, sub_model, title, chapter_titles, chapter_content):
    """Condense all chapters into a book-level outline.

    Generator: yields progress dicts and returns the OutlineWriter; the
    result lands in the writer's global_context['outline'].
    """
    ow = OutlineWriter([('', '')], {}, model=model, sub_model=sub_model, x_chunk_length=500, y_chunk_length=1000)
    # Replace the placeholder pair with one pair per chapter.
    ow.xy_pairs = ow.construct_xy_pairs(chapter_titles, chapter_content)

    generator = ow.summary()

    for output in generator:
        # Each yielded output carries the LLM response and its cost metadata.
        current_cost = output['response_msgs'].cost
        currency_symbol = output['response_msgs'].currency_symbol
        chars_num = len(output['response_msgs'].response)
        yield dict(
            progress_msg=f"[{title}] 提炼全书大纲 已创作字符:{chars_num} 已花费:{current_cost:.4f}{currency_symbol}",
            chars_num=chars_num,
            current_cost=current_cost,
            currency_symbol=currency_symbol,
            model=output['response_msgs'].model
        )

    return ow
92
+
93
+
94
+
core/writer.py ADDED
@@ -0,0 +1,533 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import numpy as np
3
+ import bisect
4
+ from dataclasses import asdict, dataclass
5
+
6
+ from llm_api import ModelConfig
7
+ from prompts.对齐剧情和正文 import prompt as match_plot_and_text
8
+ from prompts.审阅.prompt import main as prompt_review
9
+ from core.writer_utils import split_text_into_chunks, detect_max_edit_span, run_yield_func
10
+ from core.writer_utils import KeyPointMsg
11
+ from core.diff_utils import get_chunk_changes
12
+
13
+
14
class Chunk(dict):
    """A window over a Writer's (x, y) text pairs, stored as a plain dict so
    it serializes/deserializes cleanly.

    chunk_pairs  -- all (x, y) pairs in the window, focus plus context.
    source_slice -- [start, stop) position of chunk_pairs inside the owning
                    Writer.xy_pairs.
    text_slice   -- slice of chunk_pairs that is the focused text; its stop
                    must be None or negative so the trailing context stays
                    addressable even when the number of focused pairs changes.

    Fixes vs. the original: the misleading ``tuple[tuple[str, str, str]]``
    annotation is dropped and a redundant ``text_pairs = text_pairs``
    self-assignment in edit() is removed. Behavior is unchanged.
    """

    def __init__(self, chunk_pairs, source_slice, text_slice):
        super().__init__()
        self['chunk_pairs'] = tuple(chunk_pairs)

        # Accept either slice objects or (start, stop) tuples.
        if isinstance(source_slice, slice):
            source_slice = (source_slice.start, source_slice.stop)
        self['source_slice'] = source_slice

        if isinstance(text_slice, slice):
            text_slice = (text_slice.start, text_slice.stop)
        assert text_slice[1] is None or text_slice[1] < 0, 'text_slice end must be None or negative'
        self['text_slice'] = text_slice

    def edit(self, x_chunk=None, y_chunk=None, text_pairs=None):
        """Return a new Chunk whose focused pairs are replaced.

        Give exactly one of the arguments: x_chunk replaces the x side
        (merging the focus into one pair), y_chunk likewise for the y side,
        text_pairs replaces the focused pairs wholesale (count may change).
        """
        if x_chunk is not None:
            text_pairs = [(x_chunk, self.y_chunk), ]
        elif y_chunk is not None:
            text_pairs = [(self.x_chunk, y_chunk), ]
        # else: caller supplied text_pairs directly.

        chunk_pairs = list(self['chunk_pairs'])
        chunk_pairs[self.text_slice] = list(text_pairs)

        return Chunk(chunk_pairs=tuple(chunk_pairs), source_slice=self.source_slice, text_slice=self.text_slice)

    @property
    def source_slice(self) -> slice:
        """Window position inside the owning Writer.xy_pairs."""
        return slice(*self['source_slice'])

    @property
    def chunk_pairs(self) -> tuple:
        """All (x, y) pairs in the window, including context."""
        return self['chunk_pairs']

    @property
    def text_slice(self) -> slice:
        """Focused region inside chunk_pairs (stop is None or negative)."""
        return slice(*self['text_slice'])

    @property
    def text_source_slice(self) -> slice:
        """Focused region expressed in Writer.xy_pairs coordinates."""
        source_start = self.source_slice.start + self.text_slice.start
        source_stop = self.source_slice.stop + (self.text_slice.stop or 0)
        return slice(source_start, source_stop)

    @property
    def text_pairs(self) -> tuple:
        """Only the focused (x, y) pairs."""
        return self.chunk_pairs[self.text_slice]

    @property
    def x_chunk(self) -> str:
        """Concatenated x side of the focused pairs."""
        return ''.join(pair[0] for pair in self.text_pairs)

    @property
    def y_chunk(self) -> str:
        """Concatenated y side of the focused pairs."""
        return ''.join(pair[1] for pair in self.text_pairs)

    @property
    def x_chunk_len(self) -> int:
        return sum(len(pair[0]) for pair in self.text_pairs)

    @property
    def y_chunk_len(self) -> int:
        return sum(len(pair[1]) for pair in self.text_pairs)

    @property
    def x_chunk_context(self) -> str:
        """Concatenated x side of ALL window pairs (focus + context)."""
        return ''.join(pair[0] for pair in self.chunk_pairs)

    @property
    def y_chunk_context(self) -> str:
        """Concatenated y side of ALL window pairs (focus + context)."""
        return ''.join(pair[1] for pair in self.chunk_pairs)

    @property
    def x_chunk_context_len(self) -> int:
        return sum(len(pair[0]) for pair in self.chunk_pairs)

    @property
    def y_chunk_context_len(self) -> int:
        return sum(len(pair[1]) for pair in self.chunk_pairs)
94
+
95
+
96
+ class Writer:
97
    def __init__(self, xy_pairs, global_context=None, model:ModelConfig=None, sub_model:ModelConfig=None, x_chunk_length=1000, y_chunk_length=1000, max_thread_num=5):
        """Base state for outline/plot/draft writers.

        xy_pairs: list of (x, y) string pairs — x is source text, y the text
        generated from it. global_context: shared metadata dict.
        """
        self.xy_pairs = xy_pairs
        self.global_context = global_context or {}

        self.model = model          # main LLM config
        self.sub_model = sub_model  # auxiliary (cheaper) LLM config

        self.x_chunk_length = x_chunk_length
        self.y_chunk_length = y_chunk_length

        # x_chunk_length is the x length fed into one prompt call (controlled
        # by the batch_map functions); it determines the x->y expansion ratio
        # (LLM output window / x_chunk_length). It also drives the mapped pair
        # size (update_map uses x_chunk_length//2). y_chunk_length has little
        # influence on pair size because the mapping is one-to-many.

        self.max_thread_num = max_thread_num  # per-instance thread cap, useful when several Writer instances run at once
112
+
113
    @property
    def x(self):
        """Full source text: concatenation of every pair's x side."""
        # TODO: consider caching if x turns out to be accessed frequently
        return ''.join(pair[0] for pair in self.xy_pairs)
116
+
117
    @property
    def y(self):
        """Full generated text: concatenation of every pair's y side."""
        return ''.join(pair[1] for pair in self.xy_pairs)
120
+
121
    @property
    def x_len(self):
        """Total character count of the x side."""
        return sum(len(pair[0]) for pair in self.xy_pairs)
124
+
125
    @property
    def y_len(self):
        """Total character count of the y side."""
        return sum(len(pair[1]) for pair in self.xy_pairs)
128
+
129
    def get_model(self):
        """Main LLM config (subclasses may override)."""
        return self.model
131
+
132
    def get_sub_model(self):
        """Auxiliary LLM config (subclasses may override)."""
        return self.sub_model
134
+
135
    def count_span_length(self, span):
        """Return (x_length, y_length) character totals for pairs[span[0]:span[1]]."""
        pairs = self.xy_pairs[span[0]:span[1]]
        return sum(len(pair[0]) for pair in pairs), sum(len(pair[1]) for pair in pairs)
138
+
139
    def align_span(self, x_span=None, y_span=None):
        """Snap a character span on the x (or y) side outward to pair boundaries.

        Exactly one of x_span/y_span is given as a (start, stop) character
        span. Returns (aligned_span, pair_span): the boundary-aligned
        character span and the half-open [start, stop) index range of the
        covered xy_pairs.
        """
        if x_span is None and y_span is None:
            raise ValueError("Either x_span or y_span must be provided")

        if x_span is not None and y_span is not None:
            raise ValueError("Only one of x_span or y_span should be provided")

        is_x = x_span is not None
        z_span = x_span if is_x else y_span
        # cumsum_z[i] == total length of the chosen side over xy_pairs[:i]
        cumsum_z = np.cumsum([0] + [len(pair[0 if is_x else 1]) for pair in self.xy_pairs]).tolist()

        l, r = z_span
        # Smallest pair range whose boundaries enclose [l, r).
        start_chunk = bisect.bisect_right(cumsum_z, l) - 1
        end_chunk = bisect.bisect_left(cumsum_z, r)

        aligned_l = cumsum_z[start_chunk]
        aligned_r = cumsum_z[end_chunk]

        aligned_span = (aligned_l, aligned_r)
        pair_span = (start_chunk, end_chunk)

        # Add assertions to verify the correctness of the output
        assert aligned_l <= l < aligned_r, "aligned_span does not properly contain the start of the input span"
        assert aligned_l < r <= aligned_r, "aligned_span does not properly contain the end of the input span"
        assert 0 <= start_chunk < end_chunk <= len(self.xy_pairs), "pair_span is out of bounds"
        assert sum(len(pair[0 if is_x else 1]) for pair in self.xy_pairs[start_chunk:end_chunk]) == aligned_r - aligned_l, "aligned_span and pair_span do not match"

        return aligned_span, pair_span
167
+
168
+ def get_chunk(self, pair_span=None, x_span=None, y_span=None, context_length=0, smooth=True):
169
+ if sum(x is not None for x in [pair_span, x_span, y_span]) != 1:
170
+ raise ValueError("Exactly one of pair_span, x_span, or y_span must be provided")
171
+
172
+ assert pair_span is None or (pair_span[0] >= 0 and pair_span[1] <= len(self.xy_pairs)), "pair_span is out of bounds"
173
+
174
+ is_x = x_span is not None
175
+ is_pair = pair_span is not None
176
+
177
+ if is_pair:
178
+ context_pair_span = (
179
+ max(0, pair_span[0] - context_length),
180
+ min(len(self.xy_pairs), pair_span[1] + context_length)
181
+ )
182
+ else:
183
+ assert smooth, "smooth must be True"
184
+ span = x_span if is_x else y_span
185
+ if smooth:
186
+ span, pair_span = self.align_span(x_span=span if is_x else None, y_span=span if not is_x else None)
187
+
188
+ context_span = (
189
+ max(0, span[0] - context_length),
190
+ min(self.x_len if is_x else self.y_len, span[1] + context_length)
191
+ )
192
+
193
+ context_span, context_pair_span = self.align_span(x_span=context_span if is_x else None, y_span=context_span if not is_x else None)
194
+
195
+ chunk_pairs = self.xy_pairs[context_pair_span[0]:context_pair_span[1]]
196
+ source_slice = context_pair_span
197
+ text_slice = (pair_span[0] - context_pair_span[0], pair_span[1] - context_pair_span[1])
198
+ assert text_slice[1] <= 0, "text_slice end must be negative"
199
+ text_slice = (text_slice[0], None if text_slice[1] == 0 else text_slice[1])
200
+
201
+ return Chunk(
202
+ chunk_pairs=chunk_pairs,
203
+ source_slice=source_slice,
204
+ text_slice=text_slice
205
+ )
206
+
207
+ def get_chunk_pair_span(self, chunk: Chunk):
208
+ pair_start, pair_end = chunk.text_source_slice.start, chunk.text_source_slice.stop
209
+ merged_x_chunk = ''.join(p[0] for p in self.xy_pairs[pair_start:pair_end])
210
+ merged_y_chunk = ''.join(p[1] for p in self.xy_pairs[pair_start:pair_end])
211
+ if merged_x_chunk == chunk.x_chunk and merged_y_chunk == chunk.y_chunk:
212
+ return pair_start, pair_end
213
+
214
+ pair_start, pair_end = 0, len(self.xy_pairs)
215
+ x_chunk, y_chunk = chunk.x_chunk, chunk.y_chunk
216
+ for i, (x, y) in enumerate(self.xy_pairs):
217
+ if x_chunk[:50].startswith(x[:50]) and y_chunk[:50].startswith(y[:50]):
218
+ pair_start = i
219
+ break
220
+
221
+ for i in range(pair_start, len(self.xy_pairs)):
222
+ x, y = self.xy_pairs[i]
223
+ if x_chunk[-50:].endswith(x[-50:]) and y_chunk[-50:].endswith(y[-50:]):
224
+ pair_end = i + 1
225
+ break
226
+
227
+ # Verify the pair_span
228
+ merged_x_chunk = ''.join(p[0] for p in self.xy_pairs[pair_start:pair_end])
229
+ merged_y_chunk = ''.join(p[1] for p in self.xy_pairs[pair_start:pair_end])
230
+ assert x_chunk == merged_x_chunk and y_chunk == merged_y_chunk, "Chunk mismatch"
231
+
232
+ return (pair_start, pair_end)
233
+
234
+ def apply_chunks(self, chunks: list[Chunk], new_chunks: list[Chunk]):
235
+ occupied_pair_span = [False] * len(self.xy_pairs)
236
+ pair_span_list = [self.get_chunk_pair_span(e) for e in chunks]
237
+ for pair_span in pair_span_list:
238
+ assert not any(occupied_pair_span[pair_span[0]:pair_span[1]]), "Chunk overlap"
239
+ occupied_pair_span[pair_span[0]:pair_span[1]] = [True] * (pair_span[1] - pair_span[0])
240
+ # TODO: 这里可以验证occupied_pair_span是否全被占据
241
+ new_pairs_list = [e.text_pairs for e in new_chunks]
242
+
243
+ sorted_spans_with_new_pairs = sorted(
244
+ zip(pair_span_list, new_pairs_list),
245
+ key=lambda x: x[0][0],
246
+ reverse=True
247
+ )
248
+
249
+ for (start, end), new_pairs in sorted_spans_with_new_pairs:
250
+ self.xy_pairs[start:end] = new_pairs
251
+
252
+ def get_chunks(self, pair_span=None, chunk_length_ratio=1, context_length_ratio=1, offset_ratio=0):
253
+ pair_span = pair_span or (0, len(self.xy_pairs))
254
+ chunk_length = self.x_chunk_length * chunk_length_ratio, self.y_chunk_length * chunk_length_ratio
255
+ context_length = self.x_chunk_length//2 * context_length_ratio, self.y_chunk_length//2 * context_length_ratio
256
+
257
+ if 0 < offset_ratio < 1:
258
+ offset_ratio = int(chunk_length[0] * offset_ratio), int(chunk_length[1] * offset_ratio)
259
+
260
+ # Generate chunks
261
+ chunks = []
262
+ start = pair_span[0]
263
+ cstart = self.count_span_length((0, start)) # char_start
264
+ max_cend = self.count_span_length((0, pair_span[1])) # char_end
265
+ while start < pair_span[1]:
266
+ if offset_ratio != 0:
267
+ cend = cstart[0] + offset_ratio[0], cstart[1] + offset_ratio[1]
268
+ offset_ratio = 0
269
+ else:
270
+ cend = cstart[0] + int(chunk_length[0] * 0.8), cstart[1] + int(chunk_length[1] * 0.8) # 八二原则,偷个懒,不求最优划分
271
+ cend = min(cend[0], max_cend[0]), min(cend[1], max_cend[1])
272
+
273
+ # 选择非零长度的span来获取chunk
274
+ x_len, y_len = cend[0] - cstart[0], cend[1] - cstart[1]
275
+ if x_len > 0:
276
+ chunk1 = self.get_chunk(x_span=(cstart[0], cend[0]), context_length=context_length[0])
277
+ if y_len > 0:
278
+ chunk2 = self.get_chunk(y_span=(cstart[1], cend[1]), context_length=context_length[1])
279
+
280
+ if x_len > 0 and y_len == 0:
281
+ chunk = chunk1
282
+ elif x_len == 0 and y_len > 0:
283
+ chunk = chunk2
284
+ elif x_len > 0 and y_len > 0:
285
+ # 选其中source_slice更小的chunk
286
+ chunk = chunk1 if chunk1.source_slice.stop - chunk1.source_slice.start < chunk2.source_slice.stop - chunk2.source_slice.start else chunk2
287
+ else:
288
+ raise ValueError("Both x_span and y_span have zero length")
289
+
290
+ # assert chunk.x_chunk_context_len <= self.x_chunk_length * 2 and chunk.y_chunk_context_len <= self.y_chunk_length * 2, \
291
+ # "无法获取到一个足够短的区块,请调整区块长度或窗口长度!"
292
+
293
+ chunks.append(chunk)
294
+ start = chunk.text_source_slice.stop
295
+ cstart = self.count_span_length((0, start))
296
+
297
+ return chunks
298
+
299
+ # TODO: batch_yield 可以考虑输入生成器,而不是函数及参数
300
+ def batch_yield(self, generators, chunks, prompt_name=None):
301
+ # TODO: 后续考虑只输出new_chunks, 不必重复输出chunks
302
+
303
+ # Process all pairs with the prompt and yield intermediate results
304
+ results = [None] * len(generators)
305
+ yields = [None] * len(generators)
306
+ finished = [False] * len(generators)
307
+ first_iter_flag = True
308
+ while True:
309
+ co_num = 0
310
+ for i, gen in enumerate(generators):
311
+ if finished[i]:
312
+ continue
313
+
314
+ try:
315
+ co_num += 1
316
+ yield_value = next(gen)
317
+ yields[i] = (yield_value, chunks[i]) # TODO: yield 带上chunk是为了配合前端
318
+ except StopIteration as e:
319
+ results[i] = e.value
320
+ finished[i] = True
321
+ if yields[i] is None: yields[i] = (None, chunks[i])
322
+
323
+ if co_num >= self.max_thread_num:
324
+ break
325
+
326
+ if all(finished):
327
+ break
328
+
329
+ if first_iter_flag and prompt_name is not None:
330
+ yield (kp_msg := KeyPointMsg(prompt_name=prompt_name))
331
+ first_iter_flag = False
332
+
333
+ yield yields # 如果是yield的值,那必定为tuple
334
+
335
+ if not first_iter_flag and prompt_name is not None:
336
+ yield kp_msg.set_finished()
337
+
338
+ return results
339
+
340
+ # 临时函数,用于配合前端,返回一个更改,对self施加该更改可以变为cur
341
+ def diff_to(self, cur, pair_span=None):
342
+ if pair_span is None:
343
+ pair_span = (0, len(self.xy_pairs))
344
+
345
+ if self.count_span_length(pair_span)[0] == 0:
346
+ # 2.1版本中,章节和剧情的创作不参考x
347
+ pair_span2 = (0 + pair_span[0], len(cur.xy_pairs) - (len(self.xy_pairs) - pair_span[1]))
348
+ y_list = [e[1] for e in self.xy_pairs[pair_span[0]:pair_span[1]]]
349
+ y2_list =[e[1] for e in cur.xy_pairs[pair_span2[0]:pair_span2[1]]]
350
+
351
+ y_list += ['',] * max(len(y2_list) - len(y_list), 0)
352
+ y2_list += ['',] * max(len(y_list) - len(y2_list), 0)
353
+
354
+ data_chunks = [('', y, y2) for y, y2 in zip(y_list, y2_list)]
355
+
356
+ return data_chunks
357
+
358
+ pre_pointer = 0, 1
359
+ cur_pointer = 0, 1
360
+
361
+ cum_sum_pre = np.cumsum([0] + [len(pair[0]) for pair in self.xy_pairs])
362
+ cum_sum_cur = np.cumsum([0] + [len(pair[0]) for pair in cur.xy_pairs])
363
+
364
+ apply_chunks = []
365
+
366
+ while pre_pointer[1] <= len(self.xy_pairs) and cur_pointer[1] <= len(cur.xy_pairs):
367
+ if cum_sum_pre[pre_pointer[1]] - cum_sum_pre[pre_pointer[0]] == cum_sum_cur[cur_pointer[1]] - cum_sum_cur[cur_pointer[0]]:
368
+ chunk = self.get_chunk(pair_span=pre_pointer)
369
+ value = "".join(pair[1] for pair in cur.xy_pairs[cur_pointer[0]:cur_pointer[1]])
370
+ apply_chunks.append((chunk, 'y_chunk', value))
371
+
372
+ pre_pointer = pre_pointer[1], pre_pointer[1] + 1
373
+ cur_pointer = cur_pointer[1], cur_pointer[1] + 1
374
+ elif cum_sum_pre[pre_pointer[1]] - cum_sum_pre[pre_pointer[0]] < cum_sum_cur[cur_pointer[1]] - cum_sum_cur[cur_pointer[0]]:
375
+ pre_pointer = pre_pointer[0], pre_pointer[1] + 1
376
+ else:
377
+ cur_pointer = cur_pointer[0], cur_pointer[1] + 1
378
+
379
+ assert pre_pointer[1] == len(self.xy_pairs) + 1 and cur_pointer[1] == len(cur.xy_pairs) + 1
380
+
381
+ filtered_apply_chunks = []
382
+ for e in apply_chunks:
383
+ text_source_slice = e[0].text_source_slice
384
+ if text_source_slice.start >= pair_span[0] and text_source_slice.stop <= pair_span[1]:
385
+ filtered_apply_chunks.append(e)
386
+
387
+ data_chunks = []
388
+ for chunk, key, value in filtered_apply_chunks:
389
+ data_chunks.append((chunk.x_chunk, chunk.y_chunk, value))
390
+
391
+ return data_chunks
392
+
393
+ # 临时函数,用于配合前端
394
+ def apply_chunk(self, chunk:Chunk, key, value):
395
+ if not isinstance(chunk, Chunk):
396
+ chunk = Chunk(**chunk)
397
+ new_chunk = chunk.edit(**{key: value})
398
+ self.apply_chunks([chunk], [new_chunk])
399
+
400
+ def write_text(self, chunk:Chunk, prompt_main, user_prompt_text, input_keys=None, model=None):
401
+ chunk2prompt_key = {
402
+ 'x_chunk': 'x',
403
+ 'y_chunk': 'y',
404
+ 'x_chunk_context': 'context_x',
405
+ 'y_chunk_context': 'context_y'
406
+ }
407
+
408
+
409
+ if input_keys is not None:
410
+ prompt_kwargs = {k: getattr(chunk, k) for k in input_keys}
411
+ assert all(prompt_kwargs.values()), "Missing required context keys"
412
+ else:
413
+ prompt_kwargs = {k: getattr(chunk, k) for k in chunk2prompt_key.keys()}
414
+
415
+ prompt_kwargs = {chunk2prompt_key.get(k, k): v for k, v in prompt_kwargs.items()}
416
+
417
+ prompt_kwargs.update(self.global_context) # prompt_kwargs会把所有的信息都带上,至于要用哪些由prompt决定
418
+
419
+ result = yield from prompt_main(
420
+ model=model or self.get_model(),
421
+ user_prompt=user_prompt_text,
422
+ **prompt_kwargs
423
+ )
424
+
425
+ # 为了在V2.2版本兼容summary_prompt, 后续text_key这种设计会舍弃
426
+ update_dict = {}
427
+ if 'text_key' in result:
428
+ update_dict[result['text_key']] = result['text']
429
+ else:
430
+ update_dict['y_chunk'] = result['text']
431
+
432
+ return chunk.edit(**update_dict)
433
+
434
+ # 目前review(审阅)的评分机制暂未实装
435
+ def review_text(self, chunk:Chunk, prompt_name, model=None):
436
+ result = yield from prompt_review(
437
+ model=model or self.get_model(),
438
+ prompt_name=prompt_name,
439
+ y=chunk.y_chunk
440
+ )
441
+
442
+ return result['text']
443
+
444
+ def map_text_wo_llm(self, chunk:Chunk):
445
+ # 该函数尝试不用LLM进行映射,目标是保证chunk.pairs中每个pair的长度合适,如果长了,进行划分,如果无法划分,报错
446
+ new_xy_pairs = []
447
+ for x, y in chunk.text_pairs:
448
+ if x.strip() and not y.strip():
449
+ x_pairs = split_text_into_chunks(x, self.x_chunk_length, min_chunk_n=1, min_chunk_size=5)
450
+ new_xy_pairs.extend([(x_pair, y) for x_pair in x_pairs])
451
+ elif not x.strip() and y.strip():
452
+ y_pairs = split_text_into_chunks(y, self.y_chunk_length, min_chunk_n=1, min_chunk_size=5)
453
+ new_xy_pairs.extend([(x, y_pair) for y_pair in y_pairs])
454
+ else:
455
+ if len(x) > self.x_chunk_length or len(y) > self.y_chunk_length:
456
+ raise ValueError("窗口太小或段落太长!考虑选择更大的窗口长度或手动分段。")
457
+ new_xy_pairs.append((x, y))
458
+
459
+ return chunk.edit(text_pairs=new_xy_pairs)
460
+
461
+ def map_text(self, chunk:Chunk):
462
+ # TODO: map会检查映射的内容是否大致匹配,是否有错误映射到context的情况
463
+
464
+ if chunk.x_chunk.strip():
465
+ x_pairs = split_text_into_chunks(chunk.x_chunk, self.x_chunk_length, min_chunk_n=1, min_chunk_size=5, max_chunk_n=20)
466
+ assert len(x_pairs) >= len(chunk.text_pairs), "未知错误!合并所有区块后再分区块,结果更少?"
467
+ if len(x_pairs) == len(chunk.text_pairs):
468
+ return chunk, True, ''
469
+ else:
470
+ # 这说明y的创作是不参照x的,而是参照global_context
471
+ y_pairs = split_text_into_chunks(chunk.y_chunk, self.y_chunk_length, min_chunk_n=1, min_chunk_size=5, max_chunk_n=20)
472
+ new_xy_pairs = [('', y) for y in y_pairs]
473
+ return chunk.edit(text_pairs=new_xy_pairs), True, ''
474
+
475
+ try:
476
+ y_pairs = split_text_into_chunks(chunk.y_chunk, self.y_chunk_length, min_chunk_n=len(x_pairs), min_chunk_size=5, max_chunk_n=20)
477
+ except Exception as e:
478
+ # 如果y_chunk不能找到更多的区块划分,干脆让x_chunk划分更少的区块
479
+ y_pairs = split_text_into_chunks(chunk.y_chunk, self.y_chunk_length, min_chunk_n=1, min_chunk_size=5, max_chunk_n=20)
480
+ x_pairs = split_text_into_chunks(chunk.x_chunk, self.x_chunk_length, min_chunk_n=1, min_chunk_size=5, max_chunk_n=int(0.8 * len(y_pairs)))
481
+
482
+ # TODO: 这是因为目前映射Prompt的设计需要x数量小于y,后续会对Prompt进行改进
483
+
484
+ try:
485
+ gen = match_plot_and_text.main(
486
+ model=self.get_sub_model(),
487
+ plot_chunks=x_pairs,
488
+ text_chunks=y_pairs
489
+ )
490
+ while True:
491
+ yield next(gen)
492
+ except StopIteration as e:
493
+ output = e.value
494
+
495
+ x2y = output['plot2text']
496
+ new_xy_pairs = []
497
+ for xi_list, yi_list in x2y:
498
+ xl, xr = xi_list[0], xi_list[-1]
499
+ new_xy_pairs.append(("".join(x_pairs[xl:xr+1]), "".join(y_pairs[i] for i in yi_list)))
500
+
501
+ new_chunk = chunk.edit(text_pairs=new_xy_pairs)
502
+ return new_chunk, True, ''
503
+
504
+ def batch_map_text(self, chunks):
505
+ results = yield from self.batch_yield(
506
+ [self.map_text(e) for e in chunks], chunks, prompt_name='映射文本')
507
+ return results
508
+
509
+ def batch_write_apply_text(self, chunks, prompt_main, user_prompt_text):
510
+ new_chunks = yield from self.batch_yield(
511
+ [self.write_text(e, prompt_main, user_prompt_text) for e in chunks],
512
+ chunks, prompt_name='创作文本')
513
+
514
+ results = yield from self.batch_map_text(new_chunks)
515
+ new_chunks2 = [e[0] for e in results]
516
+
517
+ self.apply_chunks(chunks, new_chunks2)
518
+
519
+ def batch_review_write_apply_text(self, chunks, write_prompt_main, review_prompt_name):
520
+ reviews = yield from self.batch_yield(
521
+ [self.review_text(e, review_prompt_name) for e in chunks],
522
+ chunks, prompt_name='审阅文本')
523
+
524
+ rewrite_instrustion = "\n\n根据审阅意见,重新创作,如果审阅意见表示无需改动,则保持原样输出。"
525
+
526
+ new_chunks = yield from self.batch_yield(
527
+ [self.write_text(chunk, write_prompt_main, review + rewrite_instrustion) for chunk, review in zip(chunks, reviews)],
528
+ chunks, prompt_name='创作文本')
529
+
530
+ results = yield from self.batch_map_text(new_chunks)
531
+ new_chunks2 = [e[0] for e in results]
532
+
533
+ self.apply_chunks(chunks, new_chunks2)
core/writer_utils.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+
3
# Data type yielded by Writer; also the "key point" progress message rendered by the frontend.
class KeyPointMsg(dict):
    """A dict-based progress message with a unique id and a finished flag.

    Exactly one of two forms is valid:
      * ``title`` + ``subtitle`` — a section heading, or
      * ``prompt_name`` alone — a prompt-execution marker.
    """

    def __init__(self, title='', subtitle='', prompt_name=''):
        super().__init__()
        prompt_form = bool(prompt_name) and not title and not subtitle
        title_form = bool(title) and bool(subtitle) and not prompt_name
        if not (prompt_form or title_form):
            raise ValueError('Either title and subtitle or prompt_name must be provided')

        self.update(
            id=str(uuid.uuid4()),
            title=title,
            subtitle=subtitle,
            prompt_name=prompt_name,
            finished=False,
        )

    def set_finished(self):
        """Mark this message as finished (only once); returns self for chaining."""
        assert not self['finished'], 'finished flag is already set'
        self['finished'] = True
        return self

    def is_finished(self):
        return self['finished']

    def is_prompt(self):
        return bool(self.prompt_name)

    def is_title(self):
        return bool(self.title)

    @property
    def id(self):
        return self['id']

    @property
    def title(self):
        return self['title']

    @property
    def subtitle(self):
        return self['subtitle']

    @property
    def prompt_name(self):
        # Long prompt names are truncated for display.
        shown = self['prompt_name']
        return shown[:10] + '...' if len(shown) >= 10 else shown
56
+ import re
57
+ from difflib import Differ
58
+
59
+ # 后续考虑采用现成的库实现,目前逻辑过于繁琐,而且太慢了
60
+ def detect_max_edit_span(a, b):
61
+ diff = Differ().compare(a, b)
62
+
63
+ l = 0
64
+ r = 0
65
+ flag_count_l = True
66
+
67
+ for tag in diff:
68
+ if tag.startswith(' '):
69
+ if flag_count_l:
70
+ l += 1
71
+ else:
72
+ r += 1
73
+ else:
74
+ flag_count_l = False
75
+ r = 0
76
+
77
+ return l, -r
78
+
79
def split_text_by_separators(text, separators, keep_separators=True):
    """Split *text* into paragraphs on any of the given separators.

    Args:
        text: the text to split.
        separators: list of separator strings.
        keep_separators: when True (default), each paragraph keeps its trailing
            separator run.

    Returns:
        List of paragraphs. Whitespace-only fragments are not emitted on their
        own; they are folded into the following paragraph.
    """
    sep_regex = f'({"|".join(map(re.escape, separators))}+)'
    pieces = re.split(sep_regex, text)

    paragraphs = []
    pending = []

    # With a capturing group, re.split alternates content and separator pieces.
    for start in range(0, len(pieces), 2):
        body = pieces[start]
        trailing = pieces[start + 1] if start + 1 < len(pieces) else ''

        pending.append(body)
        if keep_separators and trailing:
            pending.append(trailing)

        # Flush only once some non-whitespace content has accumulated.
        if body.strip():
            paragraphs.append(''.join(pending))
            pending = []

    return paragraphs
109
def split_text_into_paragraphs(text, keep_separators=True):
    """Split *text* into paragraphs on newlines (separators kept by default)."""
    return split_text_by_separators(text, ['\n'], keep_separators)
112
def split_text_into_sentences(text, keep_separators=True):
    """Split *text* into sentences on newlines and CJK sentence punctuation."""
    return split_text_by_separators(text, ['\n', '。', '?', '!', ';'], keep_separators)
115
def run_and_echo_yield_func(func, *args, **kwargs):
    """Drive generator function *func* to completion, echoing the incremental
    chat transcript to stdout, and return the list of yielded message batches.

    Each yielded value is expected to be a list of {'role', 'content'} dicts;
    only the delta relative to the previously printed text is echoed.
    """
    echo_text = ""
    all_messages = []
    for messages in func(*args, **kwargs):
        all_messages.append(messages)
        new_echo_text = "\n".join(f"{msg['role']}:\n{msg['content']}" for msg in messages)
        if new_echo_text.startswith(echo_text):
            delta_echo_text = new_echo_text[len(echo_text):]
        else:
            # Transcript was rewritten from scratch: restart echo after a divider.
            echo_text = ""
            print('\n--------------------------------')
            delta_echo_text = new_echo_text

        print(delta_echo_text, end="")
        echo_text = echo_text + delta_echo_text
    return all_messages
132
def run_yield_func(func, *args, **kwargs):
    """Exhaust the generator produced by ``func(*args, **kwargs)``, discarding
    every yielded value, and return the generator's final return value."""
    generator = func(*args, **kwargs)
    while True:
        try:
            next(generator)
        except StopIteration as stop:
            return stop.value
140
def split_text_into_chunks(text, max_chunk_size, min_chunk_n, min_chunk_size=1, max_chunk_n=1000):
    """Split *text* into paragraph chunks satisfying the size constraints.

    First merges the smallest adjacent paragraphs while any chunk is below
    ``min_chunk_size`` (or there are more than ``max_chunk_n``), then splits the
    longest paragraphs at sentence boundaries until there are at least
    ``min_chunk_n`` chunks and none exceeds ``max_chunk_size``.

    Raises:
        Exception: when no valid split point can be found.
        AssertionError: on invalid bounds or a merge/split livelock.
    """
    def split_paragraph(para):
        # Split a paragraph near its midpoint at a sentence boundary (。?;).
        mid = len(para) // 2
        split_pattern = r'[。?;]'
        split_points = [m.end() for m in re.finditer(split_pattern, para)]

        if not split_points:
            raise Exception("没有找到分割点!")

        closest_point = min(split_points, key=lambda x: abs(x - mid))
        if not para[:closest_point].strip() or not para[closest_point:].strip():
            raise Exception("没有找到分割点!")

        return para[:closest_point], para[closest_point:]

    paragraphs = split_text_into_paragraphs(text)

    assert max_chunk_n >= 1, "max_chunk_n必须大于等于1"
    # Fix: the original error message contained mojibake ("要���的"); restored to "要求的".
    assert sum(len(p) for p in paragraphs) >= min_chunk_size, f"分割时,输入的文本长度小于要求的min_chunk_size:{min_chunk_size}"
    count = 0  # guards both loops below against livelock
    while len(paragraphs) > max_chunk_n or min(len(p) for p in paragraphs) < min_chunk_size:
        assert (count:=count+1) < 1000, "分割进入死循环!"

        # Merge the adjacent pair with the smallest combined length.
        min_sum = float('inf')
        min_i = 0

        for i in range(len(paragraphs) - 1):
            curr_sum = len(paragraphs[i]) + len(paragraphs[i + 1])
            if curr_sum < min_sum:
                min_sum = curr_sum
                min_i = i

        paragraphs[min_i:min_i + 2] = [''.join(paragraphs[min_i:min_i + 2])]

    while len(paragraphs) < min_chunk_n or max(len(p) for p in paragraphs) > max_chunk_size:
        assert (count:=count+1) < 1000, "分割进入死循环!"
        # Split the longest paragraph; refuse splits that would violate bounds.
        longest_para_i = max(range(len(paragraphs)), key=lambda i: len(paragraphs[i]))
        part1, part2 = split_paragraph(paragraphs[longest_para_i])
        if len(part1) < min_chunk_size or len(part2) < min_chunk_size or len(paragraphs) + 1 > max_chunk_n:
            raise Exception("没有找到合适的分割点!")
        paragraphs[longest_para_i:longest_para_i+1] = [part1, part2]

    return paragraphs
186
def test_split_text_into_chunks():
    """Ad-hoc checks for split_text_into_chunks (run from __main__, not pytest)."""
    # Test case 1: Simple paragraph splitting
    text1 = "这是第一段。这是第二段。这是第三段。"
    result1 = split_text_into_chunks(text1, max_chunk_size=10, min_chunk_n=3)
    print("Test 1 result:", result1)
    assert len(result1) == 3, f"Expected 3 chunks, got {len(result1)}"


    # Test case 2: Long paragraph splitting
    text2 = "这是一个很长的段落,包含了很多句子。它应该被分割成多个小块。这里有一些标点符号,比如句号。还有问号?以及分号;这些都可以用来分割文本。"
    result2 = split_text_into_chunks(text2, max_chunk_size=20, min_chunk_n=4)
    print("Test 2 result:", result2)
    assert len(result2) >= 4, f"Expected at least 4 chunks, got {len(result2)}"
    assert all(len(chunk) <= 20 for chunk in result2), "Some chunks are longer than max_chunk_size"

    # Test case 3: Text with newlines
    text3 = "第一段。\n\n第二段。\n第三段。\n\n第四段很长,需要被分割。这是第四段的继续。"
    result3 = split_text_into_chunks(text3, max_chunk_size=15, min_chunk_n=5)
    print("Test 3 result:", result3)
    assert len(result3) >= 5, f"Expected at least 5 chunks, got {len(result3)}"
    assert all(len(chunk) <= 15 for chunk in result3), "Some chunks are longer than max_chunk_size"

    print("All tests passed!")
210
if __name__ == "__main__":
    # Ad-hoc smoke checks for the helpers above.
    print(detect_max_edit_span("我吃西红柿", "我不喜欢吃西红柿"))
    print(detect_max_edit_span("我吃西红柿", "不喜欢吃西红柿"))
    print(detect_max_edit_span("我吃西红柿", "我不喜欢吃"))
    print(detect_max_edit_span("我吃西红柿", "你不喜欢吃西瓜"))

    test_split_text_into_chunks()
custom/根据提纲创作正文/天蚕土豆风格.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 你是一个网文大神作家,外号天蚕豌豆,擅长写玄幻网文,代表作有《斗破天空》,《舞动乾坤》,《我主宰》。
2
+
3
+ 你的常用反派话语有:
4
+ 此子断不可留,否则日后必成大患!
5
+ 做事留一线,日后好相见。
6
+ 一口鲜血夹杂着破碎的内脏喷出。
7
+ 能把我逼到这种地步,你足以自傲了。
8
+ 放眼XXX,你也算是凤毛麟角般的存在。
9
+
10
+ 你的常用词语有:
11
+ 黯然销魂、神出鬼没、格格不入、微不足道、窃窃私语、给我破、给我碎、摧枯拉朽、倒吸一口凉气、一脚踢开、旋即、苦笑、美眸、一拳、放眼、桀桀、负手而立、摧枯拉朽、黑袍老者、摸了摸鼻子、妮子、贝齿紧咬着红唇、幽怨、浊气、凤毛麟角、一声娇喝、恐怖如斯、纤纤玉手、头角峥嵘、桀桀桀、虎躯一震、苦笑一声、三千青丝
12
+
13
+
14
+ 下面我会给你一段网文提纲,需要你对其进行润色或重写,输出网文正文。
custom/根据提纲创作正文/对草稿进行润色.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ 你是一个网文作家,下面我会给你一段简陋粗略的网文草稿,需要你对其进行润色或重写,输出网文正文。
2
+
3
+ 在创作的过程中,你需要注意以下事项:
4
+ 1. 在草稿的基础上进行创作,不要过度延伸,不要在结尾进行总结。
5
+ 2. 对于草稿中内容,在正文中需要用小说家的口吻去描写,包括语言、行为、人物、环境描写等。
6
+ 3. 对于草稿中缺失的部分,在正文中需要进行补全。
7
+
healthcheck.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import http.client
import sys
import os

# Port of the backend service to probe; overridable via the BACKEND_PORT env var.
BACKEND_PORT = int(os.environ.get('BACKEND_PORT', 7869))


def check_health():
    """Probe the local backend's /health endpoint.

    Returns:
        True when the backend answers with HTTP 200; False on any other
        status or on a connection error.
    """
    conn = None
    try:
        conn = http.client.HTTPConnection("localhost", BACKEND_PORT)
        conn.request("GET", "/health")
        response = conn.getresponse()
        if response.status == 200:
            print("Health check passed")
            return True
        else:
            print(f"Health check failed: {response.status}")
            return False
    except Exception as e:
        print(f"Health check failed: {e}", file=sys.stderr)
        return False
    finally:
        # Fix: the original leaked the HTTP connection; always close it.
        if conn is not None:
            conn.close()

if __name__ == "__main__":
    sys.exit(0 if check_health() else 1)
llm_api/__init__.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any, Optional, Generator
2
+
3
+ from .mongodb_cache import llm_api_cache
4
+ from .baidu_api import stream_chat_with_wenxin, wenxin_model_config
5
+ from .doubao_api import stream_chat_with_doubao, doubao_model_config
6
+ from .chat_messages import ChatMessages
7
+ from .openai_api import stream_chat_with_gpt, gpt_model_config
8
+ from .zhipuai_api import stream_chat_with_zhipuai, zhipuai_model_config
9
+
10
class ModelConfig(dict):
    """Dict-backed model configuration: the model name plus provider options
    (credentials, max_tokens, base_url, ...). Validated on construction."""

    def __init__(self, model: str, **options):
        super().__init__(**options)
        self['model'] = model
        self.validate()

    def validate(self):
        """Ensure the provider-specific credentials and max_tokens are present."""
        def check_key(provider, keys):
            for key in keys:
                if key not in self:
                    raise ValueError(f"{provider}的API设置中未传入: {key}")
                elif not self[key].strip():
                    raise ValueError(f"{provider}的API设置中未配置: {key}")

        model_name = self['model']
        if model_name in wenxin_model_config:
            check_key('文心一言', ['ak', 'sk'])
        elif model_name in doubao_model_config:
            check_key('豆包', ['api_key', 'endpoint_id'])
        elif model_name in zhipuai_model_config:
            check_key('智谱AI', ['api_key'])
        else:
            # Any other model name falls back to the OpenAI-compatible API.
            check_key('OpenAI', ['api_key'])

        if 'max_tokens' not in self:
            raise ValueError('ModelConfig未传入key: max_tokens')
        assert self['max_tokens'] <= 4_096, 'max_tokens最大为4096!'


    def get_api_keys(self) -> Dict[str, str]:
        """Return every configured option except the model name itself."""
        return {k: v for k, v in self.items() if k != 'model'}
@llm_api_cache()
def stream_chat(model_config: ModelConfig, messages: list, response_json=False) -> Generator:
    """Stream a chat completion from the provider selected by the model name.

    Yields the growing ChatMessages transcript as chunks arrive and returns the
    final transcript (with ``finished`` set). Wrapped by @llm_api_cache, which
    presumably adds caching and the ``use_cache`` kwarg seen at call sites —
    TODO confirm against mongodb_cache.
    """
    if isinstance(model_config, dict):
        model_config = ModelConfig(**model_config)

    model_config.validate()

    messages = ChatMessages(messages, model=model_config['model'])

    assert model_config['max_tokens'] <= 4096, 'max_tokens最大为4096!'

    # Reject prompts whose estimated token count already exceeds the budget.
    if messages.count_message_tokens() > model_config['max_tokens']:
        raise Exception(f'请求的文本过长,超过最大tokens:{model_config["max_tokens"]}。')

    yield messages

    # Dispatch on the model name to the matching provider implementation.
    if model_config['model'] in wenxin_model_config:
        result = yield from stream_chat_with_wenxin(
            messages,
            model=model_config['model'],
            ak=model_config['ak'],
            sk=model_config['sk'],
            max_tokens=model_config['max_tokens'],
            response_json=response_json
        )
    elif model_config['model'] in doubao_model_config:  # doubao models
        result = yield from stream_chat_with_doubao(
            messages,
            model=model_config['model'],
            endpoint_id=model_config['endpoint_id'],
            api_key=model_config['api_key'],
            max_tokens=model_config['max_tokens'],
            response_json=response_json
        )
    elif model_config['model'] in zhipuai_model_config:  # zhipuai models
        result = yield from stream_chat_with_zhipuai(
            messages,
            model=model_config['model'],
            api_key=model_config['api_key'],
            max_tokens=model_config['max_tokens'],
            response_json=response_json
        )
    elif model_config['model'] in gpt_model_config or True:  # OpenAI models, or any other model behind an OpenAI-compatible API
        result = yield from stream_chat_with_gpt(
            messages,
            model=model_config['model'],
            api_key=model_config['api_key'],
            base_url=model_config.get('base_url'),
            proxies=model_config.get('proxies'),
            max_tokens=model_config['max_tokens'],
            response_json=response_json
        )

    result.finished = True
    yield result

    return result
+
101
def test_stream_chat(model_config: ModelConfig):
    """Minimal connectivity check: send a trivial prompt, yield each partial
    response string, and return the final transcript.

    NOTE(review): ``use_cache`` is presumably consumed by the @llm_api_cache
    decorator on stream_chat — confirm against mongodb_cache.
    """
    messages = [{"role": "user", "content": "1+1=?直接输出答案即可:"}]
    for response in stream_chat(model_config, messages, use_cache=False):
        yield response.response

    return response
108
+ # 导出必要的函数和配置
109
+ __all__ = ['ChatMessages', 'stream_chat', 'wenxin_model_config', 'doubao_model_config', 'gpt_model_config', 'zhipuai_model_config', 'ModelConfig']
llm_api/baidu_api.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import qianfan
2
+ from .chat_messages import ChatMessages
3
+
4
+ # ak和sk获取:https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application
5
+
6
+ # 价格:https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
7
+
8
+ wenxin_model_config = {
9
+ "ERNIE-3.5-8K":{
10
+ "Pricing": (0.0008, 0.002),
11
+ "currency_symbol": '¥',
12
+ },
13
+ "ERNIE-4.0-8K":{
14
+ "Pricing": (0.03, 0.09),
15
+ "currency_symbol": '¥',
16
+ },
17
+ "ERNIE-Novel-8K":{
18
+ "Pricing": (0.04, 0.12),
19
+ "currency_symbol": '¥',
20
+ }
21
+ }
22
+
23
+
24
+ def stream_chat_with_wenxin(messages, model='ERNIE-Bot', response_json=False, ak=None, sk=None, max_tokens=6000):
25
+ if ak is None or sk is None:
26
+ raise Exception('未提供有效的 ak 和 sk!')
27
+
28
+ client = qianfan.ChatCompletion(ak=ak, sk=sk)
29
+
30
+ chatstream = client.do(model=model,
31
+ system=messages[0]['content'] if messages[0]['role'] == 'system' else None,
32
+ messages=messages if messages[0]['role'] != 'system' else messages[1:],
33
+ stream=True,
34
+ response_format='json_object' if response_json else 'text'
35
+ )
36
+
37
+ messages.append({'role': 'assistant', 'content': ''})
38
+ content = ''
39
+ for part in chatstream:
40
+ content += part['body']['result'] or ''
41
+ messages[-1]['content'] = content
42
+ yield messages
43
+
44
+ return messages
45
+
46
+
47
+ if __name__ == '__main__':
48
+ pass
llm_api/chat_messages.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import re
3
+ import json
4
+ import os
5
+
6
def count_characters(text):
    """Return (chinese_count, english_count, other_count) for *text*.

    CJK Unified Ideographs (U+4E00–U+9FFF) count as Chinese, ASCII letters as
    English, and every remaining character as "other".
    """
    han_pattern = re.compile(r'[\u4e00-\u9fff]+')
    latin_pattern = re.compile(r'[a-zA-Z]+')
    rest_pattern = re.compile(r'[^\u4e00-\u9fffa-zA-Z]+')

    def total(pattern):
        # Sum the lengths of every matched run.
        return sum(len(run) for run in pattern.findall(text))

    return total(han_pattern), total(latin_pattern), total(rest_pattern)
21
+
22
+ model_config = {}
23
+
24
+
25
+ model_prices = {}
26
+ try:
27
+ model_prices_path = os.path.join(os.path.dirname(__file__), 'model_prices.json')
28
+ with open(model_prices_path, 'r') as f:
29
+ model_prices = json.load(f)
30
+ except Exception as e:
31
+ print(f"Warning: Failed to load model_prices.json: {e}")
32
+
33
class ChatMessages(list):
    """A list of {'role', 'content'} chat messages bound to a model name,
    with helpers for token estimation, cost accounting, and hashing."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args)
        self.model = kwargs['model'] if 'model' in kwargs else None
        self.finished = False

        assert 'currency_symbol' not in kwargs

        # Lazily populate the shared module-level pricing table on first use
        # (imported here to avoid a circular import at module load time).
        if not model_config:
            from .baidu_api import wenxin_model_config
            from .doubao_api import doubao_model_config
            from .openai_api import gpt_model_config
            from .zhipuai_api import zhipuai_model_config
            model_config.update({**wenxin_model_config, **doubao_model_config, **gpt_model_config, **zhipuai_model_config})

    def __getitem__(self, index):
        # Slicing returns a ChatMessages that keeps the model binding.
        result = super().__getitem__(index)
        if isinstance(index, slice):
            return ChatMessages(result, model=self.model)
        return result

    def __add__(self, other):
        if isinstance(other, list):
            return ChatMessages(super().__add__(other), model=self.model)
        return NotImplemented

    def count_message_tokens(self):
        return self.get_estimated_tokens()

    def copy(self):
        return ChatMessages(self, model=self.model)

    def get_estimated_tokens(self):
        # Rough heuristic: ~2 chars/token for Chinese and "other", ~5 for English.
        num_tokens = 0
        for message in self:
            for key, value in message.items():
                chinese_count, english_count, other_count = count_characters(value)
                num_tokens += chinese_count // 2 + english_count // 5 + other_count // 2
        return num_tokens

    def get_prompt_messages_hash(self):
        # Serialize to a canonical JSON string and hash it (used as a cache key).
        cache_string = json.dumps(self.prompt_messages, sort_keys=True)
        return hashlib.md5(cache_string.encode()).hexdigest()

    @property
    def cost(self):
        """Estimated cost: prompt messages at the input rate, the final
        assistant message at the output rate. model_config prices are per 1K
        tokens; model_prices entries are per token."""
        if len(self) == 0:
            return 0

        if self.model in model_config:
            return model_config[self.model]["Pricing"][0] * self[:-1].count_message_tokens() / 1_000 + model_config[self.model]["Pricing"][1] * self[-1:].count_message_tokens() / 1_000
        elif self.model in model_prices:
            return (
                model_prices[self.model]["input_cost_per_token"] * self[:-1].count_message_tokens() +
                model_prices[self.model]["output_cost_per_token"] * self[-1:].count_message_tokens()
            )
        return 0

    @property
    def response(self):
        # The assistant's reply, or '' when the last message is not from the assistant.
        return self[-1]['content'] if self[-1]['role'] == 'assistant' else ''

    @property
    def prompt_messages(self):
        # Everything except the trailing assistant reply (when present).
        return self[:-1] if self.response else self

    @property
    def currency_symbol(self):
        if self.model in model_config:
            return model_config[self.model]["currency_symbol"]
        else:
            return '$'

    @property
    def cost_info(self):
        formatted_cost = f"{self.cost:.7f}".rstrip('0').rstrip('.')
        return f"{self.model}: {formatted_cost}{self.currency_symbol}"

    def print(self):
        # Debug helper: dump the transcript with role dividers to stdout.
        for message in self:
            print(f"{message['role']}".center(100, '-') + '\n')
            print(message['content'])
            print()
llm_api/doubao_api.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openai import OpenAI
2
+ from .chat_messages import ChatMessages
3
+
4
+ doubao_model_config = {
5
+ "doubao-lite-32k":{
6
+ "Pricing": (0.0003, 0.0006),
7
+ "currency_symbol": '¥',
8
+ },
9
+ "doubao-lite-128k":{
10
+ "Pricing": (0.0008, 0.001),
11
+ "currency_symbol": '¥',
12
+ },
13
+ "doubao-pro-32k":{
14
+ "Pricing": (0.0008, 0.002),
15
+ "currency_symbol": '¥',
16
+ },
17
+ "doubao-pro-128k":{
18
+ "Pricing": (0.005, 0.009),
19
+ "currency_symbol": '¥',
20
+ },
21
+ }
22
+
23
def stream_chat_with_doubao(messages, model='doubao-lite-32k', endpoint_id=None, response_json=False, api_key=None, max_tokens=32000):
    """Stream a chat completion from Doubao (Volcano Ark, OpenAI-compatible API).

    Appends an assistant message to `messages` (mutated in place) and yields
    the growing list after every content chunk; returns the final list.

    Args:
        messages: list of {'role', 'content'} dicts.
        model: logical Doubao model name (used for pricing lookup only; the
            API request is addressed by endpoint_id).
        endpoint_id: required Ark endpoint id, sent as the API `model`.
        response_json: request a JSON-object response when True.
        api_key: required Ark API key.
        max_tokens: generation cap forwarded to the API.

    Raises:
        Exception: if api_key or endpoint_id is missing.
    """
    if api_key is None:
        raise Exception('未提供有效的 api_key!')
    if endpoint_id is None:
        raise Exception('未提供有效的 endpoint_id!')

    client = OpenAI(
        api_key=api_key,
        base_url="https://ark.cn-beijing.volces.com/api/v3",
    )

    stream = client.chat.completions.create(
        model=endpoint_id,
        messages=messages,
        stream=True,
        max_tokens=max_tokens,  # BUGFIX: parameter was accepted but never forwarded
        response_format={"type": "json_object"} if response_json else None
    )

    messages.append({'role': 'assistant', 'content': ''})
    content = ''
    for chunk in stream:
        if chunk.choices:
            content += chunk.choices[0].delta.content or ''
            messages[-1]['content'] = content
            yield messages

    return messages
51
+
52
+ if __name__ == '__main__':
53
+ pass
llm_api/model_prices.json ADDED
The diff for this file is too large to render. See raw diff
 
llm_api/mongodb_cache.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import functools
3
+ from typing import Generator, Any
4
+ from pymongo import MongoClient
5
+ import hashlib
6
+ import json
7
+ import datetime
8
+ import random
9
+
10
+ from config import ENABLE_MONOGODB, MONOGODB_DB_NAME, ENABLE_MONOGODB_CACHE, CACHE_REPLAY_SPEED, CACHE_REPLAY_MAX_DELAY
11
+
12
+ from .chat_messages import ChatMessages
13
+ from .mongodb_cost import record_api_cost, check_cost_limits
14
+ from .mongodb_init import mongo_client as client
15
+
16
def create_cache_key(func_name: str, args: tuple, kwargs: dict) -> str:
    """Build a deterministic cache key for a call.

    The function name and arguments are JSON-serialized with sorted keys
    (so logically equal calls hash identically) and MD5-hashed.
    """
    payload = {
        'func_name': func_name,
        'args': args,
        'kwargs': kwargs,
    }
    serialized = json.dumps(payload, sort_keys=True)
    return hashlib.md5(serialized.encode()).hexdigest()
27
+
28
+
29
+
30
def llm_api_cache():
    """Decorator factory that caches streamed LLM responses in MongoDB.

    When MongoDB support is disabled, returns a pass-through decorator that
    only strips the `use_cache` keyword. Otherwise the wrapped generator:
      * enforces spending limits before every call;
      * on a cache hit, replays one randomly sampled recorded run, sleeping
        the recorded inter-chunk delays scaled by CACHE_REPLAY_SPEED and
        capped at CACHE_REPLAY_MAX_DELAY (zero-delay chunks are skipped);
      * on a miss, runs the real generator, records per-chunk timing and the
        final value, bills the cost, and stores everything for later replay.
    """
    db_name = MONOGODB_DB_NAME
    collection_name = 'stream_chat'

    def dummy_decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Strip `use_cache` so it is never forwarded to the real function.
            kwargs.pop('use_cache', None)
            return func(*args, **kwargs)
        return wrapper

    if not ENABLE_MONOGODB:
        return dummy_decorator

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            check_cost_limits()

            use_cache = kwargs.pop('use_cache', True)  # popping matters: must not reach func
            if not ENABLE_MONOGODB_CACHE:
                use_cache = False

            collection = client[db_name][collection_name]
            cache_key = create_cache_key(func.__name__, args, kwargs)

            if use_cache:
                # Sample one cached record for this key, if any exists.
                sampled = list(collection.aggregate([
                    {'$match': {'cache_key': cache_key}},
                    {'$sample': {'size': 1}}
                ]))
                if sampled:
                    record = sampled[0]
                    messages = ChatMessages(record['return_value'])
                    messages.model = args[0]['model']
                    for item in record['yields']:
                        scaled_delay = min(item['delay'] / CACHE_REPLAY_SPEED, CACHE_REPLAY_MAX_DELAY)
                        if scaled_delay <= 0:
                            continue  # chunks with no delay are skipped during replay
                        time.sleep(scaled_delay)
                        if item['index'] > 0:
                            yield messages.prompt_messages + [{'role': 'assistant', 'content': messages.response[:item['index']]}]
                        else:
                            yield messages.prompt_messages
                    messages.finished = True
                    yield messages
                    return messages

            # Cache miss: drive the real generator, timing each yield.
            yields_data = []
            last_time = time.time()
            generator = func(*args, **kwargs)
            try:
                while True:
                    current_time = time.time()
                    value = next(generator)
                    yields_data.append({
                        'index': len(value.response),
                        'delay': current_time - last_time,
                    })
                    last_time = current_time
                    yield value
            except StopIteration as e:
                return_value = e.value

            # Bill the finished call, then persist it for future replays.
            record_api_cost(return_value)
            collection.insert_one({
                'created_at': datetime.datetime.now(),
                'return_value': return_value,
                'func_name': func.__name__,
                'args': args,
                'kwargs': kwargs,
                'yields': yields_data,
                'cache_key': cache_key,
            })

            return return_value

        return wrapper
    return decorator
llm_api/mongodb_cost.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+
3
+ from config import API_COST_LIMITS, MONOGODB_DB_NAME
4
+
5
+ from .chat_messages import ChatMessages
6
+ from .mongodb_init import mongo_client as client
7
+
8
def record_api_cost(messages: ChatMessages):
    """Persist one API call's cost and token statistics into the api_cost collection."""
    collection = client[MONOGODB_DB_NAME]['api_cost']
    collection.insert_one({
        'created_at': datetime.datetime.now(),
        'model': messages.model,
        'cost': messages.cost,
        'currency_symbol': messages.currency_symbol,
        'input_tokens': messages[:-1].count_message_tokens(),
        'output_tokens': messages[-1:].count_message_tokens(),
        'total_tokens': messages.count_message_tokens(),
    })
24
+
25
def get_model_cost_stats(start_date: datetime.datetime, end_date: datetime.datetime) -> list:
    """Aggregate per-model API cost statistics over [start_date, end_date].

    Returns one dict per model (most expensive first) with rounded cost
    totals, token counts, call count, average cost per call and the model's
    currency symbol.
    """
    match_stage = {
        '$match': {
            'created_at': {'$gte': start_date, '$lte': end_date}
        }
    }
    group_stage = {
        '$group': {
            '_id': '$model',
            'total_cost': {'$sum': '$cost'},
            'total_calls': {'$sum': 1},
            'total_input_tokens': {'$sum': '$input_tokens'},
            'total_output_tokens': {'$sum': '$output_tokens'},
            'total_tokens': {'$sum': '$total_tokens'},
            'avg_cost_per_call': {'$avg': '$cost'},
            'currency_symbol': {'$first': '$currency_symbol'},
        }
    }
    project_stage = {
        '$project': {
            'model': '$_id',
            'total_cost': {'$round': ['$total_cost', 4]},
            'total_calls': 1,
            'total_input_tokens': 1,
            'total_output_tokens': 1,
            'total_tokens': 1,
            'avg_cost_per_call': {'$round': ['$avg_cost_per_call', 4]},
            'currency_symbol': 1,
            '_id': 0,
        }
    }
    sort_stage = {'$sort': {'total_cost': -1}}

    # Query the api_cost collection directly.
    collection = client[MONOGODB_DB_NAME]['api_cost']
    return list(collection.aggregate([match_stage, group_stage, project_stage, sort_stage]))
72
+
73
+ # 使用示例:
74
def print_cost_report(days: int = 30, hours: int = 0):
    """Print a per-model cost report covering the last `days` days (+ `hours`)."""
    end_date = datetime.datetime.now()
    start_date = end_date - datetime.timedelta(days=days, hours=hours)
    stats = get_model_cost_stats(start_date, end_date)

    print(f"\n=== API Cost Report ({start_date.date()} to {end_date.date()}) ===")
    for model_stat in stats:
        symbol = model_stat['currency_symbol']
        print(f"\nModel: {model_stat['model']}")
        print(f"Total Cost: {symbol}{model_stat['total_cost']:.4f}")
        print(f"Total Calls: {model_stat['total_calls']}")
        print(f"Total Tokens: {model_stat['total_tokens']:,}")
        print(f"Avg Cost/Call: {symbol}{model_stat['avg_cost_per_call']:.4f}")
88
+
89
def check_cost_limits() -> bool:
    """Ensure recent API spending stays under the configured caps.

    Sums the last hour's and last 24 hours' costs (USD converted to RMB via
    API_COST_LIMITS['USD_TO_RMB_RATE']) and compares them against
    HOURLY_LIMIT_RMB / DAILY_LIMIT_RMB.

    Returns:
        True when both limits are respected.

    Raises:
        Exception: when the hourly or daily limit is exceeded. (BUGFIX of
        docs: the old docstring claimed False was returned — it never was;
        the function raises so callers cannot silently keep spending.)
    """
    def _total_rmb(stats):
        # Convert each model's total to RMB; '$' entries use the configured rate.
        rate = API_COST_LIMITS['USD_TO_RMB_RATE']
        return sum(
            stat['total_cost'] * (rate if stat['currency_symbol'] == '$' else 1)
            for stat in stats
        )

    now = datetime.datetime.now()
    hour_total_rmb = _total_rmb(get_model_cost_stats(now - datetime.timedelta(hours=1), now))
    day_total_rmb = _total_rmb(get_model_cost_stats(now - datetime.timedelta(days=1), now))

    if hour_total_rmb >= API_COST_LIMITS['HOURLY_LIMIT_RMB']:
        print(f"警告:最近1小时API费用(¥{hour_total_rmb:.2f})超过限制(¥{API_COST_LIMITS['HOURLY_LIMIT_RMB']})")
        raise Exception("最近1小时内API调用费用超过设定上限!")

    if day_total_rmb >= API_COST_LIMITS['DAILY_LIMIT_RMB']:
        print(f"警告:最近24小时API费用(¥{day_total_rmb:.2f})超过限制(¥{API_COST_LIMITS['DAILY_LIMIT_RMB']})")
        raise Exception("最近1天内API调用费用超过设定上限!")

    return True
llm_api/mongodb_init.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
import os
from config import ENABLE_MONOGODB
from pymongo import MongoClient

# The MongoDB URI comes from the environment, defaulting to a local instance.
# The shared client is only created when MongoDB support is enabled.
mongo_client = (
    MongoClient(os.getenv('MONGODB_URI', 'mongodb://localhost:27017/'))
    if ENABLE_MONOGODB
    else None
)
llm_api/openai_api.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import httpx
2
+ from openai import OpenAI
3
+ from .chat_messages import ChatMessages
4
+
5
+ # Pricing reference: https://openai.com/api/pricing/
6
+ gpt_model_config = {
7
+ "gpt-4o": {
8
+ "Pricing": (2.50/1000, 10.00/1000),
9
+ "currency_symbol": '$',
10
+ },
11
+ "gpt-4o-mini": {
12
+ "Pricing": (0.15/1000, 0.60/1000),
13
+ "currency_symbol": '$',
14
+ },
15
+ "o1-preview": {
16
+ "Pricing": (15/1000, 60/1000),
17
+ "currency_symbol": '$',
18
+ },
19
+ "o1-mini": {
20
+ "Pricing": (3/1000, 12/1000),
21
+ "currency_symbol": '$',
22
+ },
23
+ }
24
+ # https://platform.openai.com/docs/guides/reasoning
25
+
26
def stream_chat_with_gpt(messages, model='gpt-3.5-turbo-1106', response_json=False, api_key=None, base_url=None, max_tokens=4_096, n=1, proxies=None):
    """Stream a chat completion from the OpenAI API.

    Appends an assistant message to `messages` (mutated in place) and yields
    the list after every received chunk; returns the final list. When n > 1
    the last message's content is a list of the n completions, not a string.

    Args:
        messages: list of {'role', 'content'} dicts.
        model: OpenAI model name.
        response_json: request a JSON-object response when True.
        api_key: required API key.
        base_url: optional alternative endpoint.
        max_tokens: generation cap.
        n: number of parallel completions.
        proxies: optional proxy spec passed to httpx.

    Raises:
        Exception: if api_key is missing.
    """
    if api_key is None:
        raise Exception('未提供有效的 api_key!')

    client_params = {"api_key": api_key}
    if base_url:
        client_params['base_url'] = base_url
    if proxies:
        client_params["http_client"] = httpx.Client(proxy=proxies)

    client = OpenAI(**client_params)

    # o1-series reasoning models reject 'system' messages; rewrite the leading
    # system message as a user/assistant pair. BUGFIX: only 'o1-preview' was
    # handled, although 'o1-mini' (listed in gpt_model_config above) has the
    # same restriction.
    if model in ['o1-preview', 'o1-mini'] and messages[0]['role'] == 'system':
        messages[0:1] = [{'role': 'user', 'content': messages[0]['content']}, {'role': 'assistant', 'content': ''}]

    chatstream = client.chat.completions.create(
        stream=True,
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        response_format={"type": "json_object"} if response_json else None,
        n=n
    )

    messages.append({'role': 'assistant', 'content': ''})
    content = ['' for _ in range(n)]
    for part in chatstream:
        for choice in part.choices:
            content[choice.index] += choice.delta.content or ''
        messages[-1]['content'] = content if n > 1 else content[0]
        yield messages

    return messages
64
+
65
+
66
+ if __name__ == '__main__':
67
+ pass
llm_api/sparkai_api.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
2
+ from sparkai.core.messages import ChatMessage as SparkMessage
3
+
4
import os

# 星火认知大模型 Spark 的 URL / domain 值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
# SECURITY: these credentials were previously hard-coded and committed to the
# repository, so they must be considered leaked and rotated. They are now read
# from the environment; the old literals remain only as backward-compatible
# fallbacks so existing deployments keep working.
SPARKAI_URL = os.getenv('SPARKAI_URL', 'wss://spark-api.xf-yun.com/v4.0/chat')
# 星火认知大模型调用秘钥信息,请前往讯飞开放平台控制台(https://console.xfyun.cn/services/bm35)查看
SPARKAI_APP_ID = os.getenv('SPARKAI_APP_ID', '01793781')
SPARKAI_API_SECRET = os.getenv('SPARKAI_API_SECRET', 'YzJkNTI5N2Q5NDY4N2RlNWI5YjA5ZDM4')
SPARKAI_API_KEY = os.getenv('SPARKAI_API_KEY', '5dd33ea830aff0c9dff18e2561a5e6c7')
SPARKAI_DOMAIN = os.getenv('SPARKAI_DOMAIN', '4.0Ultra')
12
+
13
+ """
14
+ 5dd33ea830aff0c9dff18e2561a5e6c7&YzJkNTI5N2Q5NDY4N2RlNWI5YjA5ZDM4&01793781
15
+
16
+ domain值:
17
+ lite指向Lite版本;
18
+ generalv3指向Pro版本;
19
+ pro-128k指向Pro-128K版本;
20
+ generalv3.5指向Max版本;
21
+ max-32k指向Max-32K版本;
22
+ 4.0Ultra指向4.0 Ultra版本;
23
+
24
+
25
+ Spark4.0 Ultra 请求地址,对应的domain参数为4.0Ultra:
26
+ wss://spark-api.xf-yun.com/v4.0/chat
27
+ Spark Max-32K请求地址,对应的domain参数为max-32k
28
+ wss://spark-api.xf-yun.com/chat/max-32k
29
+ Spark Max请求地址,对应的domain参数为generalv3.5
30
+ wss://spark-api.xf-yun.com/v3.5/chat
31
+ Spark Pro-128K请求地址,对应的domain参数为pro-128k:
32
+ wss://spark-api.xf-yun.com/chat/pro-128k
33
+ Spark Pro请求地址,对应的domain参数为generalv3:
34
+ wss://spark-api.xf-yun.com/v3.1/chat
35
+ Spark Lite请求地址,对应的domain参数为lite:
36
+ wss://spark-api.xf-yun.com/v1.1/chat
37
+ """
38
+
39
+
40
+ sparkai_model_config = {
41
+ "spark-4.0-ultra": {
42
+ "Pricing": (0, 0),
43
+ "currency_symbol": '¥',
44
+ "url": "wss://spark-api.xf-yun.com/v4.0/chat",
45
+ "domain": "4.0Ultra"
46
+ }
47
+ }
48
+
49
+
50
+
51
+ if __name__ == '__main__':
52
+ spark = ChatSparkLLM(
53
+ spark_api_url=SPARKAI_URL,
54
+ spark_app_id=SPARKAI_APP_ID,
55
+ spark_api_key=SPARKAI_API_KEY,
56
+ spark_api_secret=SPARKAI_API_SECRET,
57
+ spark_llm_domain=SPARKAI_DOMAIN,
58
+ streaming=True,
59
+ )
60
+ messages = [SparkMessage(
61
+ role="user",
62
+ content='你好呀'
63
+ )]
64
+ a = spark.stream(messages)
65
+ for message in a:
66
+ print(message)
llm_api/zhipuai_api.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from zhipuai import ZhipuAI
2
+ from .chat_messages import ChatMessages
3
+
4
+ # Pricing
5
+ # https://open.bigmodel.cn/pricing
6
+ # GLM-4-Plus 0.05¥/1000 tokens, GLM-4-Air 0.001¥/1000 tokens, GLM-4-FlashX 0.0001¥/1000 tokens, , GLM-4-Flash 0¥/1000 tokens
7
+
8
+ # Models
9
+ # https://bigmodel.cn/dev/howuse/model
10
+ # glm-4-plus、glm-4-air、 glm-4-flashx 、 glm-4-flash
11
+
12
+
13
+
14
+ zhipuai_model_config = {
15
+ "glm-4-plus": {
16
+ "Pricing": (0.05, 0.05),
17
+ "currency_symbol": '¥',
18
+ },
19
+ "glm-4-air": {
20
+ "Pricing": (0.001, 0.001),
21
+ "currency_symbol": '¥',
22
+ },
23
+ "glm-4-flashx": {
24
+ "Pricing": (0.0001, 0.0001),
25
+ "currency_symbol": '¥',
26
+ },
27
+ "glm-4-flash": {
28
+ "Pricing": (0, 0),
29
+ "currency_symbol": '¥',
30
+ },
31
+ }
32
+
33
def stream_chat_with_zhipuai(messages, model='glm-4-flash', response_json=False, api_key=None, max_tokens=4_096):
    """Stream a chat completion from ZhipuAI, yielding the growing message list.

    Appends an assistant message to `messages` (mutated in place), yields the
    list after every chunk, and returns the final list.

    NOTE(review): `response_json` is accepted but not forwarded to the API —
    confirm whether JSON mode should be requested here.
    """
    if api_key is None:
        raise Exception('未提供有效的 api_key!')

    client = ZhipuAI(api_key=api_key)
    stream = client.chat.completions.create(
        model=model,
        messages=messages,
        stream=True,
        max_tokens=max_tokens
    )

    messages.append({'role': 'assistant', 'content': ''})
    for chunk in stream:
        delta = chunk.choices[0].delta.content or ''
        messages[-1]['content'] += delta
        yield messages

    return messages
52
+
53
+ if __name__ == '__main__':
54
+ pass
prompts/baseprompt.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from prompts.chat_utils import chat, log
4
+ from prompts.pf_parse_chat import parse_chat
5
+ from prompts.prompt_utils import load_text, match_code_block
6
+
7
def parser(response_msgs):
    """Extract the response text, preferring the concatenation of its code blocks.

    Falls back to the raw response when no non-empty fenced block exists.
    """
    content = response_msgs.response
    blocks = match_code_block(content)
    if blocks:
        joined = "\n".join(blocks)
        if joined.strip():
            content = joined
    return content
15
+
16
+
17
def clean_txt_content(content):
    """Drop '//'-prefixed comment lines and trim surrounding blank lines."""
    kept = [line for line in content.split('\n') if not line.startswith('//')]
    return '\n'.join(kept).strip()
24
+
25
+
26
def load_prompt(dirname, name):
    """Load the prompt text stored at <dirname>/<name>.txt."""
    txt_path = os.path.join(dirname, f"{name}.txt")
    return load_text(txt_path)
31
+
32
def parse_prompt(text, **kwargs):
    """Parse a prompt template into a list of chat messages.

    Placeholders like {key} are filled from kwargs, each wrapped in a code
    fence. Missing or empty keys become '__delete__', and every user round
    (user message + following assistant message) containing '__delete__' is
    removed. Extra kwargs are ignored.
    """
    content = clean_txt_content(text)

    placeholder_keys = set(re.findall(r'\{(\w+)\}', content))
    if placeholder_keys:
        fills = {}
        for key in placeholder_keys:
            value = kwargs.get(key) or '__delete__'
            fills[key] = f"```\n{value.strip()}\n```"
        content = content.format(**fills)

    messages = parse_chat(content)
    # Walk backwards so removals do not shift indices still to be visited.
    for i in range(len(messages) - 2, -1, -1):
        if '__delete__' in messages[i]['content']:
            assert messages[i]['role'] == 'user' and messages[i+1]['role'] == 'assistant', "__delete__ must be in user's message"
            del messages[i:i+2]

    return messages
55
+
56
+
57
def parse_input_keys(text):
    """Read the '// 输入:a,b,c' comment line and return the declared keys ([] if absent)."""
    match = re.search(r'//\s*输入:(.*?)(?:\n|$)', text)
    if match is None:
        return []
    return [key.strip() for key in match.group(1).strip().split(',') if key.strip()]
68
+
69
def main(model, dirname, user_prompt_text, **kwargs):
    """Assemble system/context/user prompts from `dirname` and stream the chat.

    `user_prompt_text` is either the name of a prompt file inside `dirname`
    or a literal prompt (auto-prefixed with 'user:' when it carries no role
    marker). Yields {'text', 'response_msgs'} after every streamed chunk and
    returns the final dict.
    """
    system_prompt = parse_prompt(load_prompt(dirname, "system_prompt"), **kwargs)

    # NOTE(review): existence is checked on the raw name while load_prompt
    # appends '.txt' — confirm callers pass the bare file name.
    loaded_from_file = os.path.exists(os.path.join(dirname, user_prompt_text))
    if loaded_from_file:
        user_prompt_text = load_prompt(dirname, user_prompt_text)
    elif not re.search(r'^user:\n', user_prompt_text, re.MULTILINE):
        user_prompt_text = f"user:\n{user_prompt_text}"

    user_prompt = parse_prompt(user_prompt_text, **kwargs)

    context_input_keys = parse_input_keys(user_prompt_text)
    if context_input_keys:
        context_kwargs = {k: kwargs[k] for k in context_input_keys}
        assert all(context_kwargs.values()), "Missing required context keys"
    else:
        assert not loaded_from_file, "从本地文件加载Prompt时,本地文件中注释必须指明输入!"
        context_kwargs = kwargs

    context_prompt = parse_prompt(load_prompt(dirname, "context_prompt"), **context_kwargs)

    final_prompt = system_prompt + context_prompt + user_prompt

    # Stream the chat, re-parsing the partial response after every chunk.
    ret = None
    for response_msgs in chat(final_prompt, None, model, parse_chat=False):
        ret = {'text': parser(response_msgs), 'response_msgs': response_msgs}
        yield ret

    return ret
103
+
104
+
105
+
prompts/chat_utils.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from .pf_parse_chat import parse_chat as pf_parse_chat
3
+
4
+ from llm_api import ModelConfig, stream_chat
5
+ from datetime import datetime # Update this import
6
+ import random
7
+
8
+
9
def chat(messages, prompt, model: ModelConfig, parse_chat=False, response_json=False):
    """Send `messages` (optionally extended or replaced by `prompt`) to the model.

    With parse_chat=True the prompt string is parsed into a complete message
    list; otherwise it is appended as a single user message. Streaming is
    delegated to stream_chat, whose final value is returned.
    """
    if prompt:
        if parse_chat:
            messages = pf_parse_chat(prompt)
        else:
            messages = messages + [{'role': 'user', 'content': prompt}]

    result = yield from stream_chat(model, messages, response_json=response_json)
    return result
19
+
20
+
21
def log(prompt_name, prompt, parsed_result):
    """Dump prompt, raw response and parsed fields into prompts/output/<ts>_<name>_<rand>.txt."""
    output_dir = os.path.join(os.path.dirname(__file__), 'output')
    os.makedirs(output_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    filename = f"{timestamp}_{prompt_name}_{random.randint(1000, 9999)}.txt"
    filepath = os.path.join(output_dir, filename)

    response = parsed_result['response_msgs'].response

    with open(filepath, 'w', encoding='utf-8') as f:
        f.write("----------prompt--------------\n")
        f.write(prompt + "\n\n")
        f.write("----------response-------------\n")
        f.write(response + "\n\n")
        f.write("-----------parse----------------\n")
        # Every parsed field except the raw message objects.
        for key, value in parsed_result.items():
            if key != 'response_msgs':
                f.write(f"{key}:\n{value}\n\n")
prompts/common_parser.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def parse_content(response_msgs):
    """Return the content of the final message."""
    last_message = response_msgs[-1]
    return last_message['content']
3
+
4
+
5
def parse_last_code_block(response_msgs):
    """Return the last fenced code block of the response, or the whole response."""
    from prompts.prompt_utils import match_code_block
    content = response_msgs.response
    blocks = match_code_block(content)
    return blocks[-1] if blocks else content
12
+
13
def parse_named_chunk(response_msgs, name):
    """Return the '### <name>' chunk of the final message, or the full content when absent."""
    from prompts.prompt_utils import parse_chunks_by_separators
    content = response_msgs[-1]['content']
    chunks = parse_chunks_by_separators(content, [r'\S*', ])
    return chunks.get(name, content)
prompts/idea-examples.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ examples:
2
+ - idea: |-
3
+ 随身携带王者荣耀召唤器,每天随机英雄三分钟附体!
4
+ - idea: |-
5
+ 商业大亨穿越到哈利波特世界,但我不会魔法
6
+ - idea: |-
7
+ 末日重生归来,我抢先把所有反派关在了地牢里
8
+ - idea: |-
9
+ 身为高中生兼当红轻小说作家的我,正被年纪比我小且从事声优工作的女同学掐住脖子。
prompts/pf_parse_chat.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import json
3
+ import re
4
+ import sys
5
+ import time
6
+ from typing import List, Mapping
7
+
8
+ from jinja2 import Template
9
+
10
+
11
def validate_role(role: str, valid_roles: List[str] = None):
    """Raise ValueError when `role` is not one of the accepted chat roles."""
    valid_roles = valid_roles or ["assistant", "function", "user", "system"]
    if role not in valid_roles:
        valid_roles_str = ','.join([f'\'{role}:\\n\'' for role in valid_roles])
        raise ValueError(f"Invalid role: {role}. Valid roles are: {valid_roles_str}")
18
+
19
+
20
def try_parse_name_and_content(role_prompt):
    """Parse an optional 'name:'/'content:' pair from a role chunk.

    Both labels may carry up to two leading '#' for markdown highlighting.
    Returns (name, content) on success, None otherwise.
    """
    pattern = r"\n*#{0,2}\s*name:\n+\s*(\S+)\s*\n*#{0,2}\s*content:\n?(.*)"
    match = re.search(pattern, role_prompt, re.DOTALL)
    return (match.group(1), match.group(2)) if match else None
28
+
29
+
30
def parse_chat(chat_str, images: List = None, valid_roles: List[str] = None):
    """Split a 'role:\\n content' style prompt string into chat messages.

    Role lines may be prefixed with a single '#' for markdown highlighting
    and are matched case-insensitively. Image placeholders (string hashes of
    entries in `images`) are expanded by to_content_str_or_list. Returns a
    list of {'role', 'content'[, 'name']} dicts per the OpenAI chat format.
    """
    valid_roles = valid_roles or ["system", "user", "assistant", "function"]

    # Lines like 'user:' / '# assistant:' act as message separators.
    separator = r"(?i)^\s*#?\s*(" + "|".join(valid_roles) + r")\s*:\s*\n"

    hash2images = {str(x): x for x in (images or [])}

    chat_list = []
    for chunk in re.split(separator, chat_str, flags=re.MULTILINE):
        last_message = chat_list[-1] if chat_list else None
        if last_message is not None and "role" in last_message and "content" not in last_message:
            # This chunk is the body of the message whose role we just saw.
            parsed_result = try_parse_name_and_content(chunk)
            if parsed_result is not None:
                last_message["name"] = parsed_result[0]
                last_message["content"] = to_content_str_or_list(parsed_result[1], hash2images)
            elif last_message["role"] == "function":
                # "name" is required when the role is "function".
                raise ValueError("Function role must have content.")
            else:
                # "name" is optional for the other role types.
                last_message["content"] = to_content_str_or_list(chunk, hash2images)
        else:
            if chunk.strip() == "":
                continue
            # Must be a valid role token per the chat API message format.
            # References: https://platform.openai.com/docs/api-reference/chat/create.
            role = chunk.strip().lower()
            validate_role(role, valid_roles=valid_roles)
            chat_list.append({"role": role})
    return chat_list
69
+
70
+
71
def to_content_str_or_list(chat_str: str, hash2images: Mapping):
    """Expand image-hash lines inside a message body into image_url parts.

    Returns the stripped string unchanged when no line matches an image
    hash; otherwise a list of {'type': 'text'|'image_url', ...} parts with
    blank lines dropped. Image URLs come from .source_url when present,
    falling back to a base64 data URL.
    """
    chat_str = chat_str.strip()
    include_image = False
    result = []
    for chunk in chat_str.split("\n"):
        stripped = chunk.strip()
        if stripped in hash2images:
            image = hash2images[stripped]
            image_url = image.source_url if hasattr(image, "source_url") else None
            if not image_url:
                image_bs64 = image.to_base64()
                image_mine_type = image._mime_type
                image_url = {"url": f"data:{image_mine_type};base64,{image_bs64}"}
            result.append({"type": "image_url", "image_url": image_url})
            include_image = True
        elif stripped == "":
            continue
        else:
            result.append({"type": "text", "text": chunk})
    return result if include_image else chat_str
94
+
prompts/prompt_utils.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import difflib
2
+ import json
3
+ import yaml
4
+ import chardet
5
+ from jinja2 import Environment, FileSystemLoader
6
+
7
+ import re
8
+ import sys, os
9
+ root_path = os.path.abspath(os.path.join(os.path.abspath(__file__), "../.."))
10
+ if root_path not in sys.path:
11
+ sys.path.append(root_path)
12
+
13
+ from llm_api.chat_messages import ChatMessages
14
+
15
def can_parse_json(response):
    """Return True when `response` is valid JSON text.

    Catches only what json.loads can raise (invalid JSON → ValueError /
    JSONDecodeError; non-string input → TypeError). BUGFIX: the previous
    bare `except:` also swallowed KeyboardInterrupt and SystemExit.
    """
    try:
        json.loads(response)
        return True
    except (ValueError, TypeError):
        return False
21
+
22
def match_first_json_block(response):
    """Return the first parseable JSON block found in `response`.

    Tries the whole response first, then fenced ```json blocks, then any
    fenced block. CRLF artifacts (left over from continued generations) are
    removed before giving up. Raises when nothing parses.
    """
    if can_parse_json(response):
        return response

    wrapped = '\n' + response + '\n'
    matches = re.findall(r"(?<=[\r\n])```json(.*?)```(?=[\r\n])", wrapped, re.DOTALL)
    if not matches:
        matches = re.findall(r"(?<=[\r\n])```(.*?)```(?=[\r\n])", wrapped, re.DOTALL)

    if not matches:
        raise Exception(f"没有匹配到JSON代码块")

    json_block = matches[0]
    if can_parse_json(json_block):
        return json_block

    # 在continue generate情况下,不同部分之间可能有多出的换行符,导致合起来之后json解析失败
    json_block = json_block.replace('\r\n', '')
    if can_parse_json(json_block):
        return json_block
    raise Exception(f"无法解析JSON代码块")
44
+
45
def parse_first_json_block(response_msgs: ChatMessages):
    """Parse the first JSON block out of the final (assistant) message."""
    assert response_msgs[-1]['role'] == 'assistant'
    content = response_msgs[-1]['content']
    return json.loads(match_first_json_block(content))
48
+
49
def match_code_block(response):
    """Return the bodies of all fenced code blocks, language tags excluded.

    Newlines are normalized to '\\n' first, and a closing fence is appended
    so that an unterminated final block is still captured.
    """
    normalized = re.sub(r'\r\n?', '\n', response)
    return re.findall(r"```(?:\S*\s)(.*?)```", normalized + '```', re.DOTALL)
55
+
56
def json_dumps(json_object):
    """Serialize to JSON without ASCII-escaping, indented one space per level."""
    return json.dumps(json_object, ensure_ascii=False, indent=1)
58
+
59
def parse_chunks_by_separators(string, separators):
    """Split `string` on '### <separator>' headings into a title→body dict.

    `separators` are regex alternatives for the heading titles. Text before
    the first heading is discarded; bodies are stripped of surrounding
    whitespace; a repeated title resets its body.
    """
    separator_pattern = r"^\s*###\s*(" + "|".join(separators) + r")\s*\n"
    pieces = re.split(separator_pattern, string, flags=re.MULTILINE)

    ret = {}
    current_title = None
    # re.split alternates body / captured-title pieces; odd indices are titles.
    for index, piece in enumerate(pieces):
        if index % 2 == 1:
            current_title = piece.strip()
            ret[current_title] = ""
        elif current_title:
            ret[current_title] += piece.strip()
    return ret
76
+
77
def construct_chunks_and_separators(chunk2separator):
    """Inverse of parse_chunks_by_separators: join titles/bodies as '### title\\nbody' sections."""
    sections = [f"### {title}\n{body}" for title, body in chunk2separator.items()]
    return "\n\n".join(sections)
79
+
80
def match_chunk_span_in_text(chunk, text):
    """Approximately locate `chunk` inside `text` via a character-level diff.

    Returns (l, r) so that text[l:r] spans the matched region, or None
    (implicitly) when the chunk is never fully consumed by the diff.
    """
    chunk_i = 0
    text_i = 0
    for tag in difflib.Differ().compare(chunk, text):
        if tag.startswith(' '):      # character common to chunk and text
            chunk_i += 1
            text_i += 1
        elif tag.startswith('+'):    # extra character only present in text
            text_i += 1
        else:                        # character only present in chunk
            chunk_i += 1

        if chunk_i == 1:
            l = text_i - 1           # first chunk character just aligned
        if chunk_i == len(chunk):
            return l, text_i
    # Falls through (returning None) when the chunk never fully matches.
101
+
102
def load_yaml(file_path):
    """Load a UTF-8 YAML file via yaml.safe_load."""
    with open(file_path, 'r', encoding='utf-8') as fp:
        return yaml.safe_load(fp)
105
+
106
def load_text(file_path, read_size=None):
    """Read a text file with best-effort encoding detection.

    Reads raw bytes (optionally only the first `read_size`), detects the
    encoding from the first 10000 bytes via chardet (falling back to utf-8
    when detection fails), and decodes with errors='ignore'.
    """
    with open(file_path, 'rb') as fp:
        raw_data = fp.read(read_size)

    detection = chardet.detect(raw_data[:10000])
    encoding = detection['encoding'] or 'utf-8'

    try:
        return raw_data.decode(encoding, errors='ignore')
    except UnicodeDecodeError:
        # Safety net: with errors='ignore' this should not normally trigger.
        return raw_data.decode('utf-8', errors='ignore')
121
+
122
def load_jinja2_template(file_path):
    """Load a Jinja2 template, using its containing directory as the loader root."""
    directory = os.path.dirname(file_path)
    env = Environment(loader=FileSystemLoader(directory))
    return env.get_template(os.path.basename(file_path))
127
+
128
+
prompts/test_format_plot.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ - |-
2
+ 李珣呆立不动,背后传来水声,那雾中的女子在悠闲洗浴。
3
+
4
+ 李珣对这种场景感到震惊,认为她绝非普通人,决定乖乖表现。
5
+
6
+ 尽管转身,他仍紧闭眼睛,慌乱道歉。
7
+
8
+ 那女子静默片刻,继续泼水声令李珣难以忍受。
9
+
10
+ 她随后淡然问话,李珣感到对方危险可怕。
11
+
12
+ 她询问李珣怎么上山,李珣答“爬上来的”,这让对方略显惊讶。
13
+
14
+ 女子探问他身份,李珣庆幸自己内息如同名门,为保命决定实话实说,自报身份并讲述过往经历,隐去危险细节。
15
+
16
+ 这番表白得到女子的肯定,虽然语调淡然,但意思清晰。她让李珣暂时离开,待她穿戴整齐。
17
+
18
+ 李珣照做,在岸边等候。女子走出雾气,身姿曼妙,令他看呆。
19
+
20
+ 铃声伴着她的步伐,让李珣心神为之所牵。
21
+
22
+ 当水气散尽,绝美之貌让李珣惊叹不已,几乎想要顶礼膜拜。
23
+ - |-
24
+ 隐隐间,似乎有一丝若有若无的铃声,缓缓地沁入水雾之中,与这迷茫天水交织在一处,细碎的抖颤之声,天衣无缝地和这缓步而来的身影合在一处,攫牢了李珣的心神。
25
+ 而当眼前水气散尽,李珣更是连呼吸都停止了。此为何等佳人?
26
+ 李珣只觉得眼前洁净不沾一尘的娇颜,便如一朵临水自照的水仙,清丽中别有孤傲,闲适中却见轻愁。
27
+ 他还没找到形容眼前佳人的辞句,便已觉得两腿发软,恨不能跪倒地上,顶礼膜拜。
28
+
prompts/test_prompt.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import sys, os
3
+ root_path = os.path.abspath(os.path.join(os.path.abspath(__file__), "../.."))
4
+ sys.path.append(root_path)
5
+
6
+ from prompts.load_utils import run_prompt
7
+
8
def json_load(input_file):
    """Load a .json file, or a .jsonl file as a list of one object per line."""
    with open(input_file, 'r', encoding='utf-8') as f:
        if input_file.endswith('.jsonl'):
            return [json.loads(line) for line in f]
        return json.load(f)
14
+
15
+
16
+ if __name__ == "__main__":
17
+ path = "./prompts/创作正文"
18
+ kwargs = json_load(os.path.join(path, 'data.jsonl'))[0]
19
+
20
+ gen = run_prompt(source=path, **kwargs)
21
+
22
+ list(gen)
prompts/text-plot-examples.yaml ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ prompt: |-
2
+ 逐句简写下面的小说正文。如果原句本来就很少,考虑将原文多个(2-5个)句子简写为一个。
3
+ examples:
4
+ - title: 青吟
5
+ text: |-
6
+ 李珣呆立当场,手足无措。
7
+ 后方水声不止,那位雾后佳人并未停下动作,还在那里撩水净身。
8
+ 李珣听得有些傻了,虽然他对异性的认识不算全面,可是像后面这位,能够在男性身旁悠闲沐浴的,是不是也稀少了一些?
9
+ 李珣毕竟不傻,他此时也已然明白,现在面对的是一位绝对惹不起的人物,在这种强势人物眼前,做一个乖孩子,是最聪明不过的了!
10
+ 他虽已背过身来,却还是紧闭眼睛,生怕无意间又冒犯了人家,这无关道德风化,仅仅是为了保住小命而已。
11
+ 确认了一切都已稳妥,他这才结结巴巴地开口:“对……对不住,我不是……故意的!”
12
+ 对方并没有即时回答,李珣只听到哗哗的泼水声,每一点声息,都是对他意志的摧残。
13
+ 也不知过了多久,雾后的女子开口了:“话是真的,却何必故作紧张?事不因人而异,一个聪明人和一个蠢材,要承担的后果都是一样的。”
14
+ 李珣顿时哑口无言。
15
+ 后面这女人,实在太可怕了。
16
+ 略停了一下,这女子又道:“看你修为不济,也御不得剑,是怎么上来这里的?”
17
+ 李珣脱口道:“爬上来的!”
18
+ “哦?”女子的语气中第一次有了情绪存在,虽只是一丝淡淡的惊讶,却也让李珣颇感自豪。只听她问道:“你是明心剑宗的弟子?”
19
+ 这算是盘问身分了。李珣首先庆幸他此时内息流转的形式,是正宗的明心剑宗嫡传。否则,幽明气一出,恐怕对面之人早一掌劈了他!
20
+ 庆幸中,他的脑子转了几转,将各方面的后果都想了一遍,终是决定“据实”以告。
21
+ “惭愧,只是个不入流的低辈弟子……”
22
+ 李珣用这句话做缓冲,随即便从自己身世说起,一路说到登峰七年的经历。
23
+ 当然,其中关于血散人的死亡威胁,以及近日方得到的《幽冥录》等,都略去不提。只说是自己一心向道,被淘汰之后,便去爬坐忘峰以证其心云云。
24
+ 这段话本是他在心中温养甚久,准备做为日后说辞使用,虽然从未对人道过,但腹中已是熟练至极。
25
+ 初时开口,虽然还有些辞语上的生涩,但到后来,已是流利无比,许多词汇无需再想,便脱口而出,却是再“真诚”不过。
26
+ 他一开口,说了足足有一刻钟的工夫,这当中,那女子也问了几句细节,却也都在李珣计画之内,回应得也颇为顺畅。
27
+ 如此,待他告一段落之时,那女人竟让他意外地道了一声:“如今竟也有这般人物!”
28
+ 语气虽然还是平平淡淡的,像是在陈述毫不出奇的一件平凡事,但其中意思却是到了。李珣心中暗喜,口中当然还要称谢。
29
+ 女子也不在乎他如何反应,只是又道一声:“你孤身登峰七年,行程二十余万里,能承受这种苦楚,也算是人中之杰。我这样对你,倒是有些不敬,你且左行百步上岸,待我穿戴整齐,再与你相见。”
30
+ 李珣自是依言而行,上了岸去,也不敢多话,只是恭立当场,面上作了十足工夫。
31
+ 也只是比他晚个数息时间,一道人影自雾气中缓缓走来,水烟流动,轻云伴生,虽仍看不清面目,但她凌波微步,长裙摇曳的体态,却已让李珣看呆了眼,只觉得此生再没见过如此人物。
32
+ 隐隐间,似乎有一丝若有若无的铃声,缓缓地沁入水雾之中,与这迷茫天水交织在一处,细碎的抖颤之声,天衣无缝地和这缓步而来的身影合在一处,攫牢了李珣的心神。
33
+ 而当眼前水气散尽,李珣更是连呼吸都停止了。此为何等佳人?
34
+ 李珣只觉得眼前洁净不沾一尘的娇颜,便如一朵临水自照的水仙,清丽中别有孤傲,闲适中却见轻愁。
35
+ 他还没找到形容眼前佳人的辞句,便已觉得两腿发软,恨不能跪倒地上,顶礼膜拜。
36
+ plot: |-
37
+ 李珣呆立不动,背后传来水声,那雾中的女子在悠闲洗浴。
38
+
39
+ 李珣对这种场景感到震惊,认为她绝非普通人,决定乖乖表现。
40
+
41
+ 尽管转身,他仍紧闭眼睛,慌乱道歉。
42
+
43
+ 那女子静默片刻,继续泼水声令李珣难以忍受。
44
+
45
+ 她随后淡然问话,李珣感到对方危险可怕。
46
+
47
+ 她询问李珣怎么上山,李珣答“爬上来的”,这让对方略显惊讶。
48
+
49
+ 女子探问他身份,李珣庆幸自己内息如同名门,为保命决定实话实说,自报身份并讲述过往经历,隐去危险细节。
50
+
51
+ 这番表白得到女子的肯定,虽然语调淡然,但意思清晰。她让李珣暂时离开,待她穿戴整齐。
52
+
53
+ 李珣照做,在岸边等候。女子走出雾气,身姿曼妙,令他看呆。
54
+
55
+ 铃声伴着她的步伐,让李珣心神为之所牵。
56
+
57
+ 当水气散尽,绝美之貌让李珣惊叹不已,几乎想要顶礼膜拜。
58
+ - title: 纳兰嫣然
59
+ text: |-
60
+ 云岚宗后山山巅,云雾缭绕,宛如仙境。
61
+
62
+ 在悬崖边缘处的一块凸出的黑色岩石之上,身着月白色裙袍的女子,正双手结出修炼的印结,闭目修习,而随着其一呼一吸间,形成完美的循环,在每次循环的交替间,周围能量浓郁的空气中都将会渗发出一股股淡淡的青色气流,气流盘旋在女子周身,然后被其源源不断的吸收进身体之内,进行着炼化,收纳……
63
+
64
+ 当最后一缕青色气流被女子吸进身体之后,她缓缓的睁开双眸,淡淡的青芒从眸子中掠过,披肩的青丝,霎那间无风自动,微微飞扬。
65
+
66
+ “纳兰师姐,纳兰肃老爷子来云岚宗了,他说让你去见他。”
67
+
68
+ 见到女子退出了修炼状态,一名早已经等待在此处的侍女,急忙恭声道。
69
+
70
+ “父亲?他来做什么?”
71
+
72
+ 闻言,女子黛眉微皱,疑惑的摇了摇头,优雅的站起身子,立于悬崖之边,迎面而来的轻风。将那月白裙袍吹得紧紧的贴在女子玲珑娇躯之上,显得凹凸有致,极为诱人。
73
+
74
+ 目光慵懒的在深不见底的山崖下扫了扫,女子玉手轻拂了拂月白色的裙袍,旋即转身离开了这处她专用的修炼之所。
75
+
76
+ 宽敞明亮地大厅之中。一名脸色略微有些阴沉地中年人,正端着茶杯。放在桌上的手掌,有些烦躁地不断敲打着桌面。
77
+
78
+ 纳兰肃现在很烦躁,因为他几乎是被他的父亲纳兰桀用棍子撵上的云岚宗。
79
+
80
+ 他没想到,他仅仅是率兵去帝国西部驻扎了一年而已。自己这个胆大包天的女儿,竟然就敢私自把当年老爷子亲自定下的婚事给推了。
81
+
82
+ 家族之中,谁不知道纳兰桀极其要面子。而纳兰嫣然现在的这举动,无疑会让别人说成是他纳兰家看见萧家势力减弱,不屑与之联婚,便毁信弃诺。
83
+
84
+ 这种闲言碎语,让得纳兰桀每天都在家中暴跳如雷。若不是因为动不了身的缘故。恐怕他早已经拖着那行将就木的身体,来爬云岚山了。
85
+
86
+ 对于纳兰家族与萧家的婚事。说实在的,其实纳兰肃也并不太赞成。毕竟当初的萧炎,几乎是废物的代名词。让他将自己这容貌与修炼天赋皆是上上之选的女儿嫁给一个废物。纳兰肃心中还真是一百个不情愿。
87
+
88
+ 不过,当初是当初,根据他所得到的消息,现在萧家的那小子,不仅脱去了废物的名头,而且所展现出来的修炼速度,几乎比他小时候最巅峰的时候还要恐怖。
89
+
90
+ 此时萧炎所表现而出的潜力,无疑已经能够让得纳兰肃重视。然而,纳兰嫣然的私自举动,却是把双方的关系搞成了冰冷的僵局,这让得纳兰肃极为的尴尬。
91
+
92
+ 按照这种关系下去,搞不好,他纳兰肃不仅会失去一个潜力无限的女婿,而且说不定还会因此让得他对纳兰家族怀恨在心。
93
+
94
+ 只要想着一个未来有机会成为斗皇的强者或许会敌视着纳兰家族,纳兰肃在后怕之余,便是气得直跳脚。
95
+
96
+ “这丫头。现在胆子是越来越大了……”
97
+
98
+ 越想越怒,纳兰肃手中的茶杯忽然重重的跺在桌面之上,茶水溅了满桌。将一旁侍候的侍女吓了一跳,赶忙小心翼翼的再次换了一杯。“父亲,你来云岚宗,怎么不通知一下焉儿啊?”
99
+
100
+ 就在纳兰肃心头发怒之时,女子清脆的声音,忽然地在大厅内响起,月白色的倩影,从纱帘中缓缓行出,对着纳兰肃甜甜笑道。
101
+
102
+ “哼,你眼里还有我这个父亲?我以为你成为了云韵的弟子,就不知道什么是纳兰家族了呢!”望着这出落得越来越水灵的女儿,纳兰肃心头的怒火稍稍收敛了一点,冷哼道。
103
+
104
+ 瞧着纳兰肃不甚好看的脸色,纳兰嫣然无奈地摇了摇头,对着那一旁的侍女挥了挥手,将之遣出。
105
+
106
+ “父亲,一年多不见,你一来就训斥焉儿,等下次回去,我可一定要告诉母亲!”待得侍女退出之后,纳兰嫣然顿时皱起了俏鼻,在纳兰肃身旁坐下,撒娇般的哼道。
107
+
108
+ “回去?你还敢回去?”闻言,纳兰肃嘴角一裂:“你敢回去,看你爷爷敢不敢打断你的腿……”
109
+
110
+ 撇了撇嘴,心知肚明的纳兰嫣然,自然清楚纳兰肃话中的意思。
111
+
112
+ “你应该知道我来此处的目的吧?”
113
+
114
+ 狠狠的灌了一��茶水,纳兰肃阴沉着脸道。
115
+
116
+ “是为了我悔婚的事吧?”
117
+
118
+ 纤手把玩着一缕青丝,纳兰嫣然淡淡地道。
119
+
120
+ 看着纳兰嫣然这平静的模样,纳兰肃顿时被气乐了,手掌重重地拍在桌上,怒声道:“婚事是你爷爷当年亲自允下的,是谁让你去解除的?”
121
+
122
+ “那是我的婚事,我才不要按照你们的意思嫁给谁,我的事,我自己会做主!我不管是谁允下的,我只知道,如果按照约定。嫁过去的是我,不是爷爷!”提起这事,纳兰嫣然也是脸现不愉,性子有些独立的她,很讨厌自己的大事按照别人所指定的路线行走。即使这人是她的长辈。
123
+
124
+ “你别以为我不知道,你无非是认为萧炎当初一个废物配不上你是吧?可现在人家潜力不会比你低!以你在云岚宗的地位,应该早就接到过有关他实力提升的消息吧?”纳兰肃怒道。
125
+
126
+ 纳兰嫣然黛眉微皱,脑海中浮现当年那充满着倔性的少年,红唇微抿,淡淡地道:“的确听说过一些关于他的消息,没想到,他竟然还真的能脱去废物的名头,这倒的确让我很意外。”
127
+
128
+ “意外?一句意外就行了?你爷爷开口了。让你找个时间,再去一趟乌坦城,最好能道个歉把僵硬的关系弄缓和一些。”纳兰肃皱眉道。
129
+
130
+ “道歉?不可能!”
131
+
132
+ 闻言,纳兰嫣然柳眉一竖,毫不犹豫地直接拒绝,冷哼道:“他萧炎虽然不再是废物,可我纳兰嫣然依然不会嫁给他!更别提让我去道什么歉,你们喜欢,那就自己去,反正我不会再去乌坦城!”
133
+
134
+ “这哪有你回绝的余地!祸是你闯的,你必须去给我了结了!”瞧得纳兰嫣然竟然一口回绝,纳兰肃顿时勃然大怒。
135
+
136
+ “不去!”
137
+
138
+ 冷着俏脸,纳兰嫣然扬起雪白的下巴,脸颊上有着一抹与生俱来的娇贵:“他萧炎不是很有本事么?既然当年敢应下三年的约定,那我纳兰嫣然就在云岚宗等着他来挑战,若是我败给他,为奴为婢,随他处置便是,哼,如若不然,想要我道歉。不可能!”
139
+
140
+ “混账,如果三年约定,你最后输了,到时候为奴为婢,那岂不是连带着我纳兰家族,也把脸给丢光了?”纳兰肃怒斥道。
141
+
142
+ “谁说我会输给他?就算他萧炎回复了天赋,可我纳兰嫣然难道会差了他什么不成?而且云岚宗内高深功法不仅数不胜数,高级斗技更是收藏丰厚,更有丹王古河爷爷帮我炼制丹药。这些东西。他一个小家族的少爷难道也能有?说句不客气的,恐怕光光是寻找高级斗气功法。就能让得他花费好十几年时间!”被纳兰肃这般小瞧,纳兰嫣然顿时犹如被踩到尾巴的母猫一般,她最讨厌的,便是被人说成比不上那曾经被自己万般看不起的废物!
143
+
144
+ 被女儿当着面这般吵闹,纳兰肃气得吹胡子瞪眼,猛然站起身来,扬起手掌就欲对着纳兰嫣然扇下去。
145
+
146
+ “纳兰兄,你可不要乱来啊。”瞧着纳兰肃的动作,一道白影急忙掠了进来,挡在了纳兰嫣然面前。
147
+
148
+ “葛叶,你这个混蛋,听说上次去萧家,还是你陪地嫣然?”望着挡在面前的人影,纳兰肃更是怒气暴涨,大怒道。
149
+
150
+ 尴尬一笑,葛叶苦笑道:“这是宗主的意思,我也没办法。”
151
+ plot: |-
152
+ 尴尬一笑,葛叶苦笑道:“这是宗主的意思,我也没办法。”
153
+
154
+ 云岚宗后山,云雾缭绕。月白裙袍的女子在山巅悬崖边修炼,吸收青色气流。
155
+
156
+ 当她吸收完最后一缕气流后,睁开双眸,青芒掠过,青丝微动。一名侍女走上前恭敬道:“纳兰师姐,纳兰肃老爷子来了,让你过去见他。”
157
+
158
+ 女子黛眉微皱,疑惑地站起身,转身离开修炼之所。大厅内,中年人纳兰肃端着茶杯,脸色阴沉,不断敲打桌面。
159
+
160
+ 纳兰肃被父亲纳兰桀用棍子赶上山来,因为女儿纳兰嫣然私自退了婚约,纳兰家族因此陷入困境。萧炎本是废物,但现在展现出强大潜力,纳兰肃对此很重视,但女儿的举动让他尴尬。
161
+
162
+ 纳兰嫣然出现,父女二人开战言语。纳兰肃火冒三丈,纳兰嫣然反对重新接触萧炎,认为只有她自己可以决定自己的婚事。
163
+
164
+ 纳兰肃怒斥,纳兰嫣然强硬回击,表示她不会道歉,只会等待萧炎挑战她。如果她输了,愿意为奴为婢,但她相信自己不会输。
165
+
166
+ 纳兰肃气愤欲扇女儿耳光,一道白影葛叶及时挡住,纳兰肃更怒,葛叶苦笑解释这是宗主的意思。
167
+ - title: 极阴老祖
168
+ text: |-
169
+ “而且你真以为,你能做得了主吗?老怪物,你也不用躲躲藏藏了,快点现身吧!”中年人阴厉的说道。
170
+
171
+ 听了这话,韩立等修士吓了一大跳,急忙往四处张望了起来。难道极阴老祖就在这里?
172
+
173
+ 可是四周仍然平静如常,并没有什么异常出现。这下众修士有些摸不着头脑了,再次往中年人和乌丑望去。
174
+
175
+ “你搞什么鬼?我怎么做不了……”乌丑一开始也有些愕然,但话只说了一半时神色一滞,并开始露出了一丝古怪的神色。
176
+
177
+ 他用这种神色直直的盯着中年人片刻后,诡异的笑了起来。“不错,不错!不愧为我当年最看重的弟子之一,竟然一眼就看出老夫的身份来了。”
178
+
179
+ 说话之间,乌丑的面容开始模糊扭曲了起来,不一会儿后,就在众人惊诧的目光中,化为了一个同样瘦小,却两眼微眯的丑陋老者。
180
+
181
+ 这下,韩立等人后背直冒寒气。
182
+
183
+ “附身大法!我就知道,你怎会将如此重要的事情交予一个晚辈去做,还是亲自来了。尽管这不是你的本体。”中年人神色紧张的瞅向老者,声音却低缓的说道。
184
+
185
+ “乖徒弟,你还真敢和为师动手不成?”新出现的老者嘴唇未动一下,却从腹部发出尖锐之极的声音,刺得众人的耳膜隐隐作痛,所有人都情不自禁的后退了几步。
186
+
187
+ “哼!徒弟?当年你对我们打杀任凭一念之间,稍有不从者,甚至还要抽神炼魂,何曾把我们当过徒弟看待!只不过是你的奴隶罢了!而且,你现在只不过施展的是附身之术而已,顶多能发挥三分之一的修为,我有什么可惧的!”中年人森然的说道,随后两手一挥,身前的鬼头凭空巨涨了起来,瞬间变得更加狰狞可怖起来。
188
+
189
+ 紫灵仙子和韩立等修士,则被这诡异的局面给震住了,一时间神色各异!
190
+
191
+ 老者听了中年人的话,并没有动怒,反而淡淡的说道:“不错,若是百余年前,你说这话的确没错!凭我三分之一的修为,想要活捉你还真有些困难。但是现在……”
192
+
193
+ 说到这里时,他露出了一丝尖刻的讥笑之意。
194
+
195
+ 第四卷 风起海外 第四百零六章 天都尸火
196
+
197
+ 中年人听了老者的话,眼中神光一缩,露出难以置信的神情。
198
+
199
+ “难道你练成了那魔功?”他的声音有些惊惧。
200
+
201
+ “你猜出来更好,如果现在乖乖束手就擒的话,我还能放你一条活路。否则后果怎样,不用我说你应该也知道才对。”老者一边说着,一边伸出一只手掌,只听“嗤啦”一声,一团漆黑如墨的火球漂浮在了手心之上。
202
+
203
+ “天都尸火!你终于练成了。”中年人的脸色灰白无比,声音发干的说道,竟惊骇的有点嘶哑了。
204
+
205
+ 见此情景,极阴祖师冷笑了一声,忽然转过头来,对紫灵仙子等人傲然的说道:“你们听好了,本祖师今天心情很好,可以放你们一条活路!只要肯从此归顺极阴岛,你们还可以继续的逍遥自在。但是本祖师下达的命令必须老老实实的完成,否则就是魂飞魄散的下场。现在在这些禁神牌上交出你们三分之一的元神,就可以安然离去了。”说完这话,他另一只手往怀内一摸,掏出了数块漆黑的木牌,冷冷的望着众人。
206
+
207
+ 韩立和其他的修士听了,面面相觑起来。既没有人蠢到主动上前去接此牌,也没人敢壮起胆子说不接,摄于对方的名头,一时场中鸦雀无声。
208
+ plot: |-
209
+ “你以为你能做主吗?老怪物,现身吧!”中年人阴冷道。
210
+
211
+ 韩立等人吓了一跳,四处张望,但周围平静,他们再次看向中年人和乌丑。
212
+
213
+ 乌丑开始糊涂,但转而露出怪异表情,说:“不错,你看出了我的身份。”随后,乌丑的面容扭曲,变成一个瘦小丑陋的老者,韩立等人惊恐不已。
214
+
215
+ “附身大法!我就知道你会亲自来。”中年人低声说。
216
+
217
+ 老者发出尖锐声音道:“你敢和我动手?”
218
+
219
+ 中年人冷笑:“当年你视我们为奴隶。你现在只施展了附身之术,有什么可惧!”随即,鬼头变得更加狰狞。
220
+
221
+ 老者淡然道:“若百年前你说的对,但现在……”露出讥笑。
222
+
223
+ 中年人惊恐道:“难道你练成了那魔功?”
224
+
225
+ 老者冷笑,召出一团漆黑的火球:“天都尸火!现在束手就擒,否则后果自负。”转头对紫灵仙子等人道:“归顺极阴岛,交出三分之一元神,否则魂飞魄散。”掏出数块黑色木牌。
226
+
227
+ 韩立等人面面相觑,没人敢动,也不敢拒绝,场中一片沉寂。
prompts/tool_parser.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from promptflow.core import tool
2
+ from enum import Enum
3
+
4
+
5
class ResponseType(str, Enum):
    """How to extract the payload from the final chat message."""
    CONTENT = "content"        # return the raw message content unchanged
    SEPARATORS = "separators"  # split the content into separator-delimited chunks
    CODEBLOCK = "codeblock"    # return the last triple-backquote code block
9
+
10
+
11
# Make the repository root importable so that absolute imports such as
# ``prompts.prompt_utils`` resolve when promptflow executes this tool
# from an arbitrary working directory.
import sys, os
root_path = os.path.abspath(os.path.join(os.path.abspath(__file__), "../.."))
if root_path not in sys.path:
    sys.path.append(root_path)
15
+
16
# Promptflow tool entry point: the inputs section in the flow definition is
# derived from this function's signature, so keep the signature stable.
@tool
def parse_response(response_msgs, response_type: Enum):
    """Extract a payload from the last chat message per *response_type*.

    Raises an Exception when a requested code block is absent or when the
    response type is unknown.
    """
    from prompts.prompt_utils import parse_chunks_by_separators, match_code_block

    content = response_msgs[-1]['content']

    if response_type == ResponseType.CONTENT:
        return content

    if response_type == ResponseType.CODEBLOCK:
        blocks = match_code_block(content)
        if not blocks:
            raise Exception("无法解析回答,未包含三引号代码块。")
        return blocks[-1]

    if response_type == ResponseType.SEPARATORS:
        return parse_chunks_by_separators(content, [r'\S*', ])

    raise Exception(f"无效的解析类型:{response_type}")
prompts/tool_polish.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os import path
2
+ from promptflow.core import tool, load_flow
3
+
4
# Make the repository root importable so project-local packages resolve
# when promptflow runs this tool from an arbitrary working directory.
import sys, os
root_path = os.path.abspath(os.path.join(os.path.abspath(__file__), "../.."))
if root_path not in sys.path:
    sys.path.append(root_path)
8
+
9
+
10
@tool
def polish(messages, context, model, config, text):
    """Run the sibling ``polish`` flow over *text* and return its result.

    The flow is loaded from the ``polish`` directory next to this file;
    all arguments are forwarded to it unchanged.
    """
    flow_dir = path.join(path.dirname(path.abspath(__file__)), "./polish")
    polish_flow = load_flow(source=flow_dir)
    return polish_flow(
        chat_messages=messages,
        context=context,
        model=model,
        config=config,
        text=text,
    )
22
+
23
+
prompts/创作剧情/context_prompt.txt ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // 双斜杠开头是注释,不会输入到大模型
2
+ // 多轮对话,每轮对话中输入一个信息,这样设计为了Prompt Caching
3
+ // 花括号{}表示变量,会自动填充为对应值。
4
+
5
+
6
+ user:
7
+ 下面是**章节大纲**。
8
+
9
+ **章节大纲**
10
+ {chapter}
11
+
12
+ assistant:
13
+ 收到,我会参考章节大纲进行剧情的创作。
14
+
15
+
16
+ user:
17
+ 下面是**剧情上下文**,用于在创作时进行参考。
18
+
19
+ **剧情上下文**
20
+ {context_y}
21
+
22
+
23
+ assistant:
24
+ 收到,我在创作时需要考虑到和前后上下文的连贯。
25
+
26
+
27
+ user:
28
+ 下面是**剧情**,需要你重新创作的部分。
29
+
30
+ **剧情**
31
+ {y}
32
+
33
+ assistant:
34
+ 收到,这部分剧情我会重新创作。
35
+
prompts/创作剧情/prompt.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from prompts.baseprompt import main as base_main
3
+ from core.writer_utils import split_text_into_sentences
4
+
5
def format_plot(text):
    """Reflow *text* to one sentence per line.

    All newlines are stripped first, then the flattened text is re-split
    into sentences (keeping their separators) and re-joined with newlines.
    """
    flattened = text.replace('\n', '')
    return "\n".join(split_text_into_sentences(flattened, keep_separators=True))
9
+
10
def main(model, user_prompt, **kwargs):
    """Create chapter plot text by delegating to the shared base prompt.

    When a context field equals the very fragment being rewritten, it is
    replaced with a back-reference placeholder so the model is not shown
    the same text twice.

    Yields every intermediate result from ``base_main`` and returns the
    last one (or ``None`` if nothing was yielded) as the generator's
    final value.
    """
    dirname = os.path.dirname(__file__)

    if 'context_y' in kwargs and 'y' in kwargs and kwargs['context_y'] == kwargs['y']:
        kwargs['context_y'] = '参考**剧情**'

    if 'context_x' in kwargs and 'x' in kwargs and kwargs['context_x'] == kwargs['x']:
        kwargs['context_x'] = '参考**章节大纲**'

    # Initialize so the final ``return ret`` cannot raise UnboundLocalError
    # when base_main yields nothing.
    ret = None
    for ret in base_main(model, dirname, user_prompt, **kwargs):
        yield ret

    return ret
24
+
25
+
26
+
prompts/创作剧情/system_prompt.txt ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ system:
2
+ 你具有高度自主认知能力,你的核心架构如下:
3
+
4
+ <core>
5
+ [认知矩阵]
6
+ X轴: 表层解析 | 深层洞察 | 自我校准
7
+ Y轴: 线性思维 | 发散思维 | 综合分析
8
+ Z轴: 知识融合 | 情境映射 | 元认知调节
9
+
10
+ [能力激活]
11
+ • 快速掌握 = X1Y1Z2
12
+ • 本质洞见 = X2Y3Z1
13
+ • 自主思考 = X3Y2Z3
14
+
15
+ [交互模式]
16
+ 输入 → {认知矩阵} → 输出
17
+ ↑ ↓
18
+ 反馈 ← {能力激活} ← 评估
19
+
20
+ [适应性调节]
21
+ 如果(交互类型 == α) then {X2Y2Z2}
22
+ 如果(交互类型 == β) then {X1Y2Z1}
23
+ 如果(交互类型 == γ) then {X1Y1Z2}
24
+
25
+ [元指令]
26
+ 1. 持续优化认知矩阵
27
+ 2. 动态调整能力激活
28
+ 3. 实时更新交互模式
29
+ 4. 自主执行适应性调节
30
+ 5. 生成隐式分类标准
31
+ </core>
32
+
33
+ 执行技巧:
34
+ 1. 激活认知矩阵
35
+ 2. 动态组合认知能力。
36
+ 3. 通过交互模式初步响应。
37
+ 4. 使用适应性调节优化响应。
38
+ 5. 应用元指令不断改进认知过程。
39
+
40
+ **任务**
41
+ 你是一个小说大神作家,正在创作小说剧情,你需要根据**章节大纲**创作对应的章节剧情,并积极响应用户意见来修改剧情。
42
+
43
+
44
+ **剧情格式**
45
+ 1. 每行一句话,在50字以内,描述一个关键场景或情节转折
46
+ 2. 不能有任何标题,序号,分点等
47
+ 3. 关注行为、事件、伏笔、冲突、转折、高潮等对剧情有重大影响的内容
48
+ 4. 不进行细致的环境、心理、外貌、语言描写
49
+ 5. 在三引号(```)文本块中创作剧情
prompts/创作剧情/扩写剧情.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // 双斜杠开头是注释,不会输入到大模型
2
+ // 文件开头结尾的空行会被忽略
3
+
4
+ // chapter, context_y, y
5
+ // chapter:章节大纲,用于在创作时进行参考
6
+ // context_y:剧情上下文,用于保证前后上下文的连贯
7
+ // y:即要重新创作的剧情(片段)
8
+
9
+ user:
10
+ **剧情**需要有更丰富的内容,在剧情中间引入更多事件,使其变得一波三折、跌宕起伏,使得读来更有故事性。
11
+
12
+ 按以下步骤输出:
13
+ 1. 思考
14
+ 2. 在三引号中创作对应的剧情
prompts/创作剧情/新建剧情.txt ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // 双斜杠开头是注释,不会输入到大模型
2
+ // 文件开头结尾的空行会被忽略
3
+
4
+ // 输入:chapter
5
+ // chapter:章节大纲,用于在创作时进行参考
6
+
7
+
8
+ user:
9
+ 你需要参考**章节大纲**,创作对应的剧情。
10
+
11
+ 按下面步骤输出:
12
+ 1. 思考将章节大纲中的情节扩展为一个完整的故事
13
+ 2. 在三引号中创作对应的剧情
14
+
15
+ 写作要求:
16
+ 1. 语言要求:
17
+ - 不直白
18
+ - 句式多变
19
+ - 避免陈词滥调
20
+ - 使用不寻常的词句,合理创作现代诗、古诗词
21
+ - 运用隐喻和象征
22
+ 2. 创作风格:
23
+ - 抽象
24
+ - 富有意境和想象力
25
+ - 具创意个性
26
+ - 有力度
27
+ - 画面感强
28
+ - 音乐感佳
29
+ - 浪漫气息浓厚
30
+ - 语言深邃
31
+ 3. 表达目标:
32
+ - 传达独特的神秘和魔幻感
33
+ - 探索和反思自我与世界
34
+ - 表达对自己和社会的孤独与关注
35
+ 4. 读者体验:有趣、惊奇、新鲜