# app.py — Flask backend entry point (novel-writing service)
# Standard library
import json
import os
import sys
import time

# Third-party
from flask import Flask, request, Response, jsonify
from flask_cors import CORS

# Make this file's directory and its parent importable BEFORE the local
# imports below — they resolve against these paths.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Local
from prompts.baseprompt import clean_txt_content, load_prompt
from core.writer_utils import KeyPointMsg
from core.draft_writer import DraftWriter
from core.plot_writer import PlotWriter
from core.outline_writer import OutlineWriter
from setting import setting_bp
from summary import process_novel
from backend_utils import get_model_config_from_provider_model
from config import MAX_NOVEL_SUMMARY_LENGTH, MAX_THREAD_NUM, ENABLE_ONLINE_DEMO

app = Flask(__name__)
CORS(app)
app.register_blueprint(setting_bp)

# Server bind address/port, overridable via environment variables.
BACKEND_HOST = os.environ.get('BACKEND_HOST', '0.0.0.0')
BACKEND_PORT = int(os.environ.get('BACKEND_PORT', 7869))
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: report service status and the current unix time."""
    payload = {
        'status': 'healthy',
        'timestamp': int(time.time()),
    }
    return jsonify(payload), 200
def load_novel_writer(writer_mode, chunk_list, global_context, x_chunk_length, y_chunk_length, main_model, sub_model, max_thread_num) -> DraftWriter:
    """Instantiate the writer matching writer_mode ('draft'/'outline'/'plot').

    Raises ValueError for an unknown writer_mode.
    """
    # Each mode pairs a writer class with the key under which its
    # global context is stored ('draft' wraps nothing: empty context).
    registry = {
        'draft': (DraftWriter, None),
        'outline': (OutlineWriter, 'summary'),
        'plot': (PlotWriter, 'chapter'),
    }
    if writer_mode not in registry:
        raise ValueError(f"unknown writer: {writer_mode}")
    writer_cls, context_key = registry[writer_mode]

    shared_kwargs = {
        'xy_pairs': chunk_list,
        'model': get_model_config_from_provider_model(main_model),
        'sub_model': get_model_config_from_provider_model(sub_model),
        'x_chunk_length': x_chunk_length,
        'y_chunk_length': y_chunk_length,
        'max_thread_num': max_thread_num,
        'global_context': {} if context_key is None else {context_key: global_context},
    }
    return writer_cls(**shared_kwargs)
# Display names of the prompt templates available for each writer mode.
prompt_names = dict(
    outline = ['新建章节', '扩写章节', '润色章节'],
    plot = ['新建剧情', '扩写剧情', '润色剧情'],
    draft = ['新建正文', '扩写正文', '润色正文'],
)
# On-disk directory holding each mode's prompt files.
prompt_dirname = dict(
    outline = 'prompts/创作章节',
    plot = 'prompts/创作剧情',
    draft = 'prompts/创作正文',
)

# Preload every prompt text at startup; served verbatim by GET /prompts.
PROMPTS = {}
for type_name, dirname in prompt_dirname.items():
    PROMPTS[type_name] = {'prompt_names': prompt_names[type_name]}
    for name in prompt_names[type_name]:
        content = clean_txt_content(load_prompt(dirname, name))
        # Drop the leading role marker if present (removeprefix is a no-op
        # otherwise; the file already requires 3.10+ for match/case).
        content = content.removeprefix("user:\n")
        PROMPTS[type_name][name] = {'content': content}
@app.route('/prompts', methods=['GET'])
def get_prompts():
    """Expose the startup-loaded prompt catalog to the frontend."""
    catalog = PROMPTS
    return jsonify(catalog)
def get_delta_chunks(prev_chunks, curr_chunks):
    """Diff two chunk snapshots.

    Returns ("delta", suffixes) when every current string merely extends its
    previous counterpart in place, otherwise ("init", curr_chunks) to signal
    that the client must do a full reset.
    """
    if not prev_chunks or len(prev_chunks) != len(curr_chunks):
        return "init", curr_chunks

    def row_extends(prev_row, curr_row):
        # A row is delta-compatible when shapes match and each string only grew.
        if len(prev_row) != len(curr_row):
            return False
        return all(curr.startswith(prev) for prev, curr in zip(prev_row, curr_row))

    if not all(row_extends(p, c) for p, c in zip(prev_chunks, curr_chunks)):
        return "init", curr_chunks

    # Emit only the newly appended suffix of every string.
    delta_chunks = [
        [curr[len(prev):] for prev, curr in zip(prev_row, curr_row)]
        for prev_row, curr_row in zip(prev_chunks, curr_chunks)
    ]
    return "delta", delta_chunks
def call_write(writer_mode, chunk_list, global_context, chunk_span, prompt_content, x_chunk_length, y_chunk_length, main_model, sub_model, max_thread_num, only_prompt):
    """Drive one writing run as a generator of SSE-ready payload dicts.

    Yields throttled progress payloads (via delta_wrapper) and finally a
    done=True payload holding the edit diff.  When only_prompt is True it
    yields a single {'prompts': ...} payload and returns early.
    """
    if ENABLE_ONLINE_DEMO:
        if max_thread_num > MAX_THREAD_NUM:
            raise Exception("在线Demo模型下,最大线程数不能超过" + str(MAX_THREAD_NUM) + "!")
    # Re-append a trailing newline to every non-empty chunk except those in the
    # last row (rows arrive from separate on-page chunks, stripped of newlines).
    chunk_list = [[e.strip() + ('\n' if e.strip() and rowi != len(chunk_list)-1 else '') for e in row] for rowi, row in enumerate(chunk_list)]
    prev_chunks = None
    def delta_wrapper(chunk_list, done=False, msg=None):
        # Strip the newlines again before handing chunks back to the frontend.
        chunk_list = [[e.strip() for e in row] for row in chunk_list]
        nonlocal prev_chunks
        if prev_chunks is None:
            # First payload of the stream: always a full snapshot.
            prev_chunks = chunk_list
            return {
                "done": done,
                "chunk_type": "init",
                "chunk_list": chunk_list,
                "msg": msg
            }
        else:
            # Later payloads: send suffix deltas when possible (see get_delta_chunks).
            chunk_type, new_chunks = get_delta_chunks(prev_chunks, chunk_list)
            prev_chunks = chunk_list
            return {
                "done": done,
                "chunk_type": chunk_type,
                "chunk_list": new_chunks,
                "msg": msg
            }
    novel_writer = load_novel_writer(writer_mode, chunk_list, global_context, x_chunk_length, y_chunk_length, main_model, sub_model, max_thread_num)
    # Draft mode needs an x->y mapping, so perform the initial split first.
    if writer_mode == 'draft':
        target_chunk = novel_writer.get_chunk(pair_span=chunk_span)
        new_target_chunk = novel_writer.map_text_wo_llm(target_chunk)
        novel_writer.apply_chunks([target_chunk], [new_target_chunk])
        chunk_span = novel_writer.get_chunk_pair_span(new_target_chunk)
    # Keep a pristine copy of the writer state so an edit diff can be computed at the end.
    init_novel_writer = load_novel_writer(writer_mode, list(novel_writer.xy_pairs), global_context, x_chunk_length, y_chunk_length, main_model, sub_model, max_thread_num)
    # TODO: writer.write should handle any prompt whether y is empty or filled.
    # In other words, "新建正文" is a convenience — "扩写正文" could achieve the same effect.
    generator = novel_writer.write(prompt_content, pair_span=chunk_span)
    prompt_outputs = []
    last_yield_time = time.time()  # Timestamp of the most recent yield (throttling).
    prompt_name = ''
    for kp_msg in generator:
        if isinstance(kp_msg, KeyPointMsg):
            # To support key-point saving, an edit diff would be computed and the
            # writer yielded here; for now only the prompt name is recorded.
            prompt_name = kp_msg.prompt_name
            continue
        else:
            chunk_list = kp_msg
        current_cost = 0
        currency_symbol = ''
        current_model = ''
        data_chunks = []
        prompt_outputs.clear()
        for e in chunk_list:
            if e is None: continue  # None: this chunk has not been processed yet.
            output, chunk = e
            if output is None: continue  # None output: chunk returned without yielding, i.e. no LLM call.
            prompt_outputs.append(output)
            current_text = ""
            current_model = output['response_msgs'].model
            current_cost += output['response_msgs'].cost
            currency_symbol = output['response_msgs'].currency_symbol
            if 'plot2text' in output:
                current_text += f"正在建立映射关系..." + '\n'
            else:
                current_text = output['text']
            data_chunks.append((chunk.x_chunk, chunk.y_chunk, current_text))
        if only_prompt:
            yield {'prompts': [e['response_msgs'] for e in prompt_outputs]}
            return
        current_time = time.time()
        if current_time - last_yield_time >= 0.2:  # Throttle progress to one payload per 0.2s.
            # NOTE(review): the trailing conditional binds to the WHOLE concatenation,
            # so msg becomes '' (progress prefix included) whenever current_model is
            # empty — confirm that dropping the prefix is intended.
            yield delta_wrapper(data_chunks, done=False, msg=f"正在 {prompt_name}{len(prompt_outputs)} / {len(chunk_list)})" + f" 模型:{current_model} 花费:{current_cost:.5f}{currency_symbol}" if current_model else '')
            last_yield_time = current_time  # Update the last yield time.
    # Compute an edit-level diff for frontend display; diff support will move out
    # of the writer later since it exists purely for presentation.
    data_chunks = init_novel_writer.diff_to(novel_writer, pair_span=chunk_span)
    yield delta_wrapper(data_chunks, done=True, msg='创作完成!')
@app.route('/write', methods=['POST'])
def write():
    """SSE endpoint running one writing task; streams progress payloads.

    The first event carries the stream_id so the client can cancel via
    POST /stop_stream; subsequent events are call_write payloads.
    """
    data = request.json
    writer_mode = data['writer_mode']
    chunk_list = data['chunk_list']
    chunk_span = data['chunk_span']
    prompt_content = data['prompt_content']
    x_chunk_length = data['x_chunk_length']
    y_chunk_length = data['y_chunk_length']
    main_model = data['main_model']
    sub_model = data['sub_model']
    global_context = data['global_context']
    only_prompt = data['only_prompt']
    # Fix: max_thread_num was previously only bound when 'settings' was present
    # in the request, raising NameError otherwise; fall back to the configured default.
    if 'settings' in data:
        max_thread_num = data['settings']['MAX_THREAD_NUM']
    else:
        max_thread_num = MAX_THREAD_NUM
    # Register this stream so /stop_stream can cancel it.
    stream_id = str(time.time())
    active_streams[stream_id] = True
    def generate():
        try:
            # First event tells the client its stream id.
            yield f"data: {json.dumps({'stream_id': stream_id})}\n\n"
            for result in call_write(writer_mode, list(chunk_list), global_context, chunk_span, prompt_content, x_chunk_length, y_chunk_length, main_model, sub_model, max_thread_num, only_prompt):
                if not active_streams.get(stream_id, False):
                    # Stream was stopped by client.
                    print(f"Stream was stopped by client: {stream_id}")
                    return
                yield f"data: {json.dumps(result)}\n\n"
        except Exception as e:
            # Surface the error inside the affected chunks so the UI can display it in place.
            error_msg = f"创作出错:\n{str(e)}"
            error_chunk_list = [[*e[:2], error_msg] for e in chunk_list[chunk_span[0]:chunk_span[1]]]
            error_data = {
                "done": True,
                "chunk_type": "init",
                "chunk_list": error_chunk_list
            }
            yield f"data: {json.dumps(error_data)}\n\n"
        finally:
            # Clean up stream tracking.
            if stream_id in active_streams:
                del active_streams[stream_id]
    return Response(generate(), mimetype='text/event-stream')
@app.route('/summary', methods=['POST'])
def process_novel_text():
    """SSE endpoint summarizing an uploaded novel; streams progress payloads.

    Progress events are throttled to one per 0.2s; a throttled-away final
    payload is flushed after the loop so the client always sees the end state.
    """
    data = request.json
    content = data['content']
    novel_name = data['novel_name']
    # Register this stream so /stop_stream can cancel it.
    stream_id = str(time.time())
    active_streams[stream_id] = True
    def generate():
        try:
            yield f"data: {json.dumps({'stream_id': stream_id})}\n\n"
            main_model = get_model_config_from_provider_model(data['main_model'])
            sub_model = get_model_config_from_provider_model(data['sub_model'])
            max_novel_summary_length = data['settings']['MAX_NOVEL_SUMMARY_LENGTH']
            max_thread_num = data['settings']['MAX_THREAD_NUM']
            last_yield_time = 0
            yield_value = None   # Most recently serialized payload.
            pending = False      # True when the latest payload was throttled away.
            for result in process_novel(content, novel_name, main_model, sub_model, max_novel_summary_length, max_thread_num):
                if not active_streams.get(stream_id, False):
                    # Stream was stopped by client.
                    print(f"Stream was stopped by client: {stream_id}")
                    return
                current_time = time.time()
                yield_value = f"data: {json.dumps(result)}\n\n"
                if current_time - last_yield_time >= 0.2:
                    last_yield_time = current_time
                    yield yield_value
                    pending = False
                else:
                    pending = True
            # Fix: the original referenced current_time/yield_value after the loop,
            # raising NameError when process_novel yielded nothing, and re-sent the
            # final payload even when it had already been yielded.  Flush only a
            # genuinely throttled-away final payload.
            if pending and yield_value is not None:
                # Debug aid kept from the original: persist the final payload to yaml.
                import yaml
                result_dict = json.loads(yield_value.replace('data: ', '').strip())
                with open('tmp.yaml', 'w', encoding='utf-8') as f:
                    yaml.dump(result_dict, f, allow_unicode=True)
                yield yield_value  # Ensure the last payload reaches the client.
        except Exception as e:
            error_data = {
                "progress_msg": f"处理出错:{str(e)}",
            }
            yield f"data: {json.dumps(error_data)}\n\n"
        finally:
            # Clean up stream tracking.
            if stream_id in active_streams:
                del active_streams[stream_id]
    return Response(generate(), mimetype='text/event-stream')
# Registry of in-flight SSE streams: stream_id -> True (running) / False (stop requested).
active_streams = {}

@app.route('/stop_stream', methods=['POST'])
def stop_stream():
    """Mark a running stream as cancelled; its generator exits on its next check."""
    requested_id = request.json.get('stream_id')
    if requested_id in active_streams:
        active_streams[requested_id] = False
    return jsonify({'success': True})
if __name__ == '__main__':
    # Start the Flask server on the configured host/port (debug disabled).
    app.run(host=BACKEND_HOST, port=BACKEND_PORT, debug=False)