import json
import time

from flask import Flask, request, Response, jsonify
from flask_cors import CORS

app = Flask(__name__)
CORS(app)

# Add this file's directory and its parent to sys.path so the project-local imports below resolve.
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from prompts.baseprompt import clean_txt_content, load_prompt

from core.writer_utils import KeyPointMsg
from core.draft_writer import DraftWriter
from core.plot_writer import PlotWriter
from core.outline_writer import OutlineWriter

from setting import setting_bp
from summary import process_novel
from backend_utils import get_model_config_from_provider_model
from config import MAX_NOVEL_SUMMARY_LENGTH, MAX_THREAD_NUM, ENABLE_ONLINE_DEMO

app.register_blueprint(setting_bp)

BACKEND_HOST = os.environ.get('BACKEND_HOST', '0.0.0.0')
BACKEND_PORT = int(os.environ.get('BACKEND_PORT', 7869))


@app.route('/health', methods=['GET'])
def health_check():
    return jsonify({
        'status': 'healthy',
        'timestamp': int(time.time())
    }), 200


def load_novel_writer(writer_mode, chunk_list, global_context, x_chunk_length, y_chunk_length, main_model, sub_model, max_thread_num) -> DraftWriter:
    """Build the writer instance for the given mode ('draft', 'outline' or 'plot')."""
    kwargs = dict(
        xy_pairs=chunk_list,
        model=get_model_config_from_provider_model(main_model),
        sub_model=get_model_config_from_provider_model(sub_model),
    )

    kwargs['x_chunk_length'] = x_chunk_length
    kwargs['y_chunk_length'] = y_chunk_length
    kwargs['max_thread_num'] = max_thread_num

    match writer_mode:
        case 'draft':
            kwargs['global_context'] = {}
            novel_writer = DraftWriter(**kwargs)
        case 'outline':
            kwargs['global_context'] = {'summary': global_context}
            novel_writer = OutlineWriter(**kwargs)
        case 'plot':
            kwargs['global_context'] = {'chapter': global_context}
            novel_writer = PlotWriter(**kwargs)
        case _:
            raise ValueError(f"unknown writer: {writer_mode}")

    return novel_writer


# Prompt display names per writer mode (新建 = create, 扩写 = expand, 润色 = polish).
prompt_names = dict(
    outline=['新建章节', '扩写章节', '润色章节'],
    plot=['新建剧情', '扩写剧情', '润色剧情'],
    draft=['新建正文', '扩写正文', '润色正文'],
)

# Directories containing the prompt template files for each writer mode.
prompt_dirname = dict(
    outline='prompts/创作章节',
    plot='prompts/创作剧情',
    draft='prompts/创作正文',
)

# Load every prompt once at startup and strip the leading "user:" role marker.
PROMPTS = {}
for type_name, dirname in prompt_dirname.items():
    PROMPTS[type_name] = {'prompt_names': prompt_names[type_name]}
    for name in prompt_names[type_name]:
        content = clean_txt_content(load_prompt(dirname, name))
        if content.startswith("user:\n"):
            content = content[len("user:\n"):]
        PROMPTS[type_name][name] = {'content': content}


@app.route('/prompts', methods=['GET'])
def get_prompts():
    return jsonify(PROMPTS)
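

# Illustrative shape of the GET /prompts response, derived from the loading loop above
# (values truncated; only 'draft' expanded):
# {
#   "draft": {
#     "prompt_names": ["新建正文", "扩写正文", "润色正文"],
#     "新建正文": {"content": "..."},
#     ...
#   },
#   "outline": {...},
#   "plot": {...}
# }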


def get_delta_chunks(prev_chunks, curr_chunks):
    """Compute the delta between previous and current chunks.

    Returns ("delta", delta_chunks) when every current string extends its previous
    counterpart, otherwise ("init", curr_chunks).
    """
    if not prev_chunks or len(prev_chunks) != len(curr_chunks):
        return "init", curr_chunks

    is_delta = True
    for prev_chunk, curr_chunk in zip(prev_chunks, curr_chunks):
        if len(prev_chunk) != len(curr_chunk):
            is_delta = False
            break
        for prev_str, curr_str in zip(prev_chunk, curr_chunk):
            if not curr_str.startswith(prev_str):
                is_delta = False
                break
        if not is_delta:
            break

    if not is_delta:
        return "init", curr_chunks

    delta_chunks = []
    for prev_chunk, curr_chunk in zip(prev_chunks, curr_chunks):
        delta_chunk = []
        for prev_str, curr_str in zip(prev_chunk, curr_chunk):
            delta_str = curr_str[len(prev_str):]
            delta_chunk.append(delta_str)
        delta_chunks.append(delta_chunk)

    return "delta", delta_chunks
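

# Illustrative behaviour (not executed):
#   get_delta_chunks([["ab", "c"]], [["abc", "cd"]])  ->  ("delta", [["c", "d"]])
#   get_delta_chunks([["ab"]],      [["xyz"]])        ->  ("init",  [["xyz"]])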


def call_write(writer_mode, chunk_list, global_context, chunk_span, prompt_content, x_chunk_length, y_chunk_length, main_model, sub_model, max_thread_num, only_prompt):
    if ENABLE_ONLINE_DEMO:
        if max_thread_num > MAX_THREAD_NUM:
            # Online demo caps the thread count.
            raise Exception("在线Demo模式下,最大线程数不能超过" + str(MAX_THREAD_NUM) + "!")

    # Normalize the chunks: strip whitespace and re-append a trailing newline to every
    # non-empty cell except those in the last row.
    chunk_list = [[e.strip() + ('\n' if e.strip() and rowi != len(chunk_list)-1 else '') for e in row] for rowi, row in enumerate(chunk_list)]

    prev_chunks = None

    def delta_wrapper(chunk_list, done=False, msg=None):
        chunk_list = [[e.strip() for e in row] for row in chunk_list]

        nonlocal prev_chunks
        if prev_chunks is None:
            prev_chunks = chunk_list
            return {
                "done": done,
                "chunk_type": "init",
                "chunk_list": chunk_list,
                "msg": msg
            }
        else:
            chunk_type, new_chunks = get_delta_chunks(prev_chunks, chunk_list)
            prev_chunks = chunk_list
            return {
                "done": done,
                "chunk_type": chunk_type,
                "chunk_list": new_chunks,
                "msg": msg
            }

    novel_writer = load_novel_writer(writer_mode, chunk_list, global_context, x_chunk_length, y_chunk_length, main_model, sub_model, max_thread_num)

    if writer_mode == 'draft':
        target_chunk = novel_writer.get_chunk(pair_span=chunk_span)
        new_target_chunk = novel_writer.map_text_wo_llm(target_chunk)
        novel_writer.apply_chunks([target_chunk], [new_target_chunk])
        chunk_span = novel_writer.get_chunk_pair_span(new_target_chunk)

    # Keep an untouched copy so the final diff can be computed against the initial state.
    init_novel_writer = load_novel_writer(writer_mode, list(novel_writer.xy_pairs), global_context, x_chunk_length, y_chunk_length, main_model, sub_model, max_thread_num)

    generator = novel_writer.write(prompt_content, pair_span=chunk_span)

    prompt_outputs = []
    last_yield_time = time.time()

    prompt_name = ''
    for kp_msg in generator:
        if isinstance(kp_msg, KeyPointMsg):
            prompt_name = kp_msg.prompt_name
            continue
        else:
            chunk_list = kp_msg

        current_cost = 0
        currency_symbol = ''
        current_model = ''
        data_chunks = []
        prompt_outputs.clear()
        for e in chunk_list:
            if e is None:
                continue
            output, chunk = e
            if output is None:
                continue
            prompt_outputs.append(output)
            current_text = ""
            current_model = output['response_msgs'].model
            current_cost += output['response_msgs'].cost
            currency_symbol = output['response_msgs'].currency_symbol
            if 'plot2text' in output:
                current_text += "正在建立映射关系..." + '\n'  # "building the plot-to-text mapping..."
            else:
                current_text = output['text']
            data_chunks.append((chunk.x_chunk, chunk.y_chunk, current_text))

        if only_prompt:
            yield {'prompts': [e['response_msgs'] for e in prompt_outputs]}
            return

        # Throttle streamed updates to at most one every 0.2 seconds.
        current_time = time.time()
        if current_time - last_yield_time >= 0.2:
            yield delta_wrapper(data_chunks, done=False, msg=f"正在 {prompt_name} ({len(prompt_outputs)} / {len(chunk_list)})" + (f" 模型:{current_model} 花费:{current_cost:.5f}{currency_symbol}" if current_model else ''))
            last_yield_time = current_time

    data_chunks = init_novel_writer.diff_to(novel_writer, pair_span=chunk_span)

    yield delta_wrapper(data_chunks, done=True, msg='创作完成!')  # "writing finished!"


@app.route('/write', methods=['POST'])
def write():
    data = request.json
    writer_mode = data['writer_mode']
    chunk_list = data['chunk_list']
    chunk_span = data['chunk_span']
    prompt_content = data['prompt_content']
    x_chunk_length = data['x_chunk_length']
    y_chunk_length = data['y_chunk_length']
    main_model = data['main_model']
    sub_model = data['sub_model']
    global_context = data['global_context']
    only_prompt = data['only_prompt']

    # Default to the configured cap when the client sends no settings; otherwise
    # max_thread_num would be unbound below.
    max_thread_num = MAX_THREAD_NUM
    if 'settings' in data:
        max_thread_num = data['settings']['MAX_THREAD_NUM']

    stream_id = str(time.time())
    active_streams[stream_id] = True

    def generate():
        try:
            # The first event carries the stream id so the client can cancel via /stop_stream.
            yield f"data: {json.dumps({'stream_id': stream_id})}\n\n"

            for result in call_write(writer_mode, list(chunk_list), global_context, chunk_span, prompt_content, x_chunk_length, y_chunk_length, main_model, sub_model, max_thread_num, only_prompt):
                if not active_streams.get(stream_id, False):
                    print(f"Stream was stopped by client: {stream_id}")
                    return

                yield f"data: {json.dumps(result)}\n\n"
        except Exception as e:
            error_msg = f"创作出错:\n{str(e)}"
            error_chunk_list = [[*e[:2], error_msg] for e in chunk_list[chunk_span[0]:chunk_span[1]]]

            error_data = {
                "done": True,
                "chunk_type": "init",
                "chunk_list": error_chunk_list
            }
            yield f"data: {json.dumps(error_data)}\n\n"
        finally:
            if stream_id in active_streams:
                del active_streams[stream_id]

    return Response(generate(), mimetype='text/event-stream')
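

# Illustrative client sketch (defined here for reference, never called by the app): how a
# caller could consume the /write SSE stream and cancel it with /stop_stream. The payload
# keys mirror the fields read in write() above; the concrete values, the chunk_list /
# chunk_span shapes, the provider-model strings and the use of the `requests` package are
# assumptions made for the example only.
def _example_write_client(base_url=f'http://localhost:{BACKEND_PORT}'):
    import requests

    payload = {
        'writer_mode': 'draft',
        'chunk_list': [['第一章剧情……', '']],                    # rows of (plot, draft-text) pairs (assumed shape)
        'chunk_span': [0, 1],                                     # span of rows to write (assumed shape)
        'prompt_content': PROMPTS['draft']['新建正文']['content'],
        'x_chunk_length': 500,
        'y_chunk_length': 1000,
        'main_model': 'openai-gpt-4o',                            # provider-model id, format assumed
        'sub_model': 'openai-gpt-4o-mini',
        'global_context': '',
        'only_prompt': False,
        'settings': {'MAX_THREAD_NUM': MAX_THREAD_NUM},
    }
    stream_id = None
    with requests.post(f'{base_url}/write', json=payload, stream=True) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            if not line or not line.startswith('data: '):
                continue
            event = json.loads(line[len('data: '):])
            if 'stream_id' in event:
                stream_id = event['stream_id']    # keep it around for a later /stop_stream call
                continue
            print(event.get('msg'), event.get('chunk_type'))
            if event.get('done'):
                break
    # To cancel an in-flight stream from elsewhere:
    # requests.post(f'{base_url}/stop_stream', json={'stream_id': stream_id})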


@app.route('/summary', methods=['POST'])
def process_novel_text():
    data = request.json
    content = data['content']
    novel_name = data['novel_name']

    stream_id = str(time.time())
    active_streams[stream_id] = True

    def generate():
        try:
            # The first event carries the stream id so the client can cancel via /stop_stream.
            yield f"data: {json.dumps({'stream_id': stream_id})}\n\n"

            main_model = get_model_config_from_provider_model(data['main_model'])
            sub_model = get_model_config_from_provider_model(data['sub_model'])
            max_novel_summary_length = data['settings']['MAX_NOVEL_SUMMARY_LENGTH']
            max_thread_num = data['settings']['MAX_THREAD_NUM']
            last_yield_time = 0
            for result in process_novel(content, novel_name, main_model, sub_model, max_novel_summary_length, max_thread_num):
                if not active_streams.get(stream_id, False):
                    print(f"Stream was stopped by client: {stream_id}")
                    return

                current_time = time.time()
                yield_value = f"data: {json.dumps(result)}\n\n"
                if current_time - last_yield_time >= 0.2:
                    last_yield_time = current_time
                    yield yield_value
                else:
                    # Rapid update: dump the latest result to tmp.yaml as a debug snapshot,
                    # then emit it anyway so the client never misses the newest state.
                    import yaml
                    result_dict = json.loads(yield_value.replace('data: ', '').strip())
                    with open('tmp.yaml', 'w', encoding='utf-8') as f:
                        yaml.dump(result_dict, f, allow_unicode=True)

                    yield yield_value

        except Exception as e:
            error_data = {
                "progress_msg": f"处理出错:{str(e)}",
            }
            yield f"data: {json.dumps(error_data)}\n\n"
        finally:
            if stream_id in active_streams:
                del active_streams[stream_id]

    return Response(generate(), mimetype='text/event-stream')


# Tracks whether each SSE stream is still active; /stop_stream flips the flag to False
# so the generators above stop on their next iteration.
active_streams = {}


@app.route('/stop_stream', methods=['POST'])
def stop_stream():
    data = request.json
    stream_id = data.get('stream_id')
    if stream_id in active_streams:
        active_streams[stream_id] = False
    return jsonify({'success': True})


if __name__ == '__main__':
    app.run(host=BACKEND_HOST, port=BACKEND_PORT, debug=False)