# novel/AIGN.py — uploaded by deeme in commit dd81a5f ("Upload 5 files", verified)
import functools
import os
import re
import time

from AIGN_Prompt import *
def Retryer(func, max_retries=10, delay=2.333):
    """Wrap *func* so that calls to it are retried when they raise.

    The returned wrapper calls ``func(*args, **kwargs)`` up to
    ``max_retries`` times and returns the first successful result. Each
    failure is printed; between attempts (but not after the last one) the
    wrapper sleeps for ``delay`` seconds.

    Args:
        func: The callable to protect.
        max_retries: Maximum number of attempts; values below 1 mean the
            wrapper raises immediately without calling *func*.
        delay: Seconds to wait between two attempts.

    Returns:
        A wrapper with the same call signature and metadata as *func*.

    Raises:
        ValueError: When every attempt failed. The last underlying exception
            is attached as ``__cause__`` so the real reason is not lost.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        last_error = None
        for attempt in range(max_retries):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                last_error = e
                print("-" * 30 + f"\n失败:\n{e}\n" + "-" * 30)
                # No point in waiting once the final attempt has failed.
                if attempt < max_retries - 1:
                    time.sleep(delay)
        raise ValueError("失败") from last_error

    return wrapper
class MarkdownAgent:
    """Chat agent whose input and output are both Markdown-style text.

    Inputs are rendered as ``# key`` sections, and the model's reply is parsed
    back into a ``{key: text}`` dict (used here for novel generation).
    """

    def __init__(
        self,
        chatLLM,
        sys_prompt: str,
        user_prompt: str,
        name: str,
        temperature=0.8,
        top_p=0.8,
        use_memory=False,
        first_replay="明白了。",
        is_speak=True,
    ) -> None:
        """Create the agent and prime its conversation history.

        Args:
            chatLLM: Callable invoked with ``messages=[...]`` (and, for
                queries, ``temperature``/``top_p``); its return value is
                indexed with ``["content"]``.
            sys_prompt: System message placed first in the history.
            user_prompt: Instruction sent as the first user turn.
            name: Human-readable agent name, stored as ``self.name``.
            temperature: Sampling temperature forwarded on every query.
            top_p: Nucleus-sampling value forwarded on every query.
            use_memory: If true, each query/reply pair is appended to history.
            first_replay: Canned assistant reply to ``user_prompt``. When
                falsy, the model is called once here to produce that reply.
            is_speak: Stored on the instance; not read by this class.
        """
        self.chatLLM = chatLLM
        self.sys_prompt = sys_prompt
        self.user_prompt = user_prompt
        # Previously accepted but discarded; kept for logging/debugging.
        self.name = name
        self.temperature = temperature
        self.top_p = top_p
        self.use_memory = use_memory
        self.is_speak = is_speak
        self.history = [
            {"role": "system", "content": self.sys_prompt},
            {"role": "user", "content": self.user_prompt},
        ]
        if first_replay:
            self.history.append({"role": "assistant", "content": first_replay})
        else:
            # No canned reply: let the model answer the instruction prompt once.
            resp = chatLLM(messages=self.history)
            self.history.append({"role": "assistant", "content": resp["content"]})

    def query(self, user_input: str) -> dict:
        """Send *user_input* after the stored history and return the raw reply.

        The reply is whatever ``chatLLM`` returned; callers read its
        ``"content"`` entry. With ``use_memory`` the exchange is appended to
        the history, otherwise the history is left untouched.
        """
        resp = self.chatLLM(
            messages=self.history + [{"role": "user", "content": user_input}],
            temperature=self.temperature,
            top_p=self.top_p,
        )
        if self.use_memory:
            self.history.append({"role": "user", "content": user_input})
            self.history.append({"role": "assistant", "content": resp["content"]})
        return resp

    def getOutput(self, input_content: str, output_keys: list) -> dict:
        """Query the model and extract the sections named in *output_keys*.

        Sections are first looked up under ``# key`` headings; any key that
        is missing or empty there is searched again under ``## key`` headings.

        Returns:
            Dict mapping every heading found (at least all *output_keys*) to
            its stripped text.

        Raises:
            ValueError: If a requested key is empty under both heading
                styles. The reply text is included to aid diagnosis.
        """
        resp = self.query(input_content)
        output = resp["content"]
        lines = output.split("\n")
        sections = self.parse_sections1(lines, output_keys)
        for k in output_keys:
            if sections.get(k):
                continue
            # Fallback: models sometimes emit "## key" instead of "# key".
            section_content = self.parse_sections2(lines, k)
            if not section_content:
                raise ValueError(f"fail to parse {k} in output:\n{output}")
            sections[k] = section_content
        return sections

    def parse_sections1(self, lines, output_keys):
        """Group *lines* under level-1 ``# heading`` lines.

        Returns a dict containing every key in *output_keys* (empty string if
        absent) plus any other heading encountered. Lines before the first
        heading are dropped, content lines are stripped, and a heading that
        appears twice keeps only its last section.
        """
        sections = {key: "" for key in output_keys}
        current_section = ""
        for line in lines:
            if line.startswith(("# ", " # ")):
                # For both "# key" and " # key", slicing at 2 and stripping
                # yields "key".
                current_section = line[2:].strip()
                sections[current_section] = []
            elif current_section:
                sections[current_section].append(line.strip())
        for key in sections:
            sections[key] = "\n".join(sections[key]).strip()
        return sections

    def parse_sections2(self, lines, k):
        """Return the text under the first ``##``-style heading mentioning *k*.

        Matching is a case-insensitive substring test on the heading line.
        Capture stops at the next ``##``/``###`` heading. Returns "" if no
        heading matches.
        """
        content = []
        capturing = False
        for line in lines:
            stripped_line = line.strip()
            # After strip() a line can't begin with a space, and "###" already
            # starts with "##", so this single prefix covers all heading forms.
            is_subheading = stripped_line.startswith("##")
            if is_subheading and k.lower() in stripped_line.lower():
                capturing = True
                continue
            elif is_subheading and capturing:
                break
            if capturing:
                content.append(stripped_line)
        return "\n".join(content).strip()

    def invoke(self, inputs: dict, output_keys: list) -> dict:
        """Render *inputs* as ``# key`` sections, query, and parse with retries.

        Only non-empty string values are sent. Parsing failures are retried
        via ``Retryer``.
        """
        input_content = "".join(
            f"# {k}\n{v}\n\n"
            for k, v in inputs.items()
            if isinstance(v, str) and len(v) > 0
        )
        return Retryer(self.getOutput)(input_content, output_keys)

    def clear_memory(self):
        """Forget remembered turns while keeping the priming conversation.

        The priming history is always three messages: system prompt, user
        prompt, and the first assistant reply. Truncating to two, as before,
        lost that reply and left two user turns in a row.
        """
        if self.use_memory:
            self.history = self.history[:3]
class AIGN:
    """Stateful novel generator that coordinates several MarkdownAgents.

    Keeps the outline, the generated paragraphs and their joined text, the
    rolling writing plan / temporary setting / memory, a chapter index
    derived from the text, and an undo stack of earlier states.
    """

    # Chapter heading at the start of a line, optionally prefixed by Markdown
    # "#" marks and/or "**" emphasis, e.g. "## 第十二章：标题" or
    # "第一百零一章 标题". Anchoring to line starts keeps a mention such as
    # "翻到第三章" inside prose from starting a new chapter.
    # Group 1 is the chapter number; group 2 is the rest of the heading line.
    _CHAPTER_PATTERN = re.compile(
        r"^[ \t]*#*[ \t]*\**[ \t]*第([一二三四五六七八九十百千万亿零〇\d]+)章[:：]?[ \t]*(.*)",
        re.MULTILINE,
    )

    def __init__(self, chatLLM):
        """Initialise empty novel state and the writer/memory agents.

        Args:
            chatLLM: Chat-completion callable shared by every agent.
        """
        self.chatLLM = chatLLM
        self.novel_outline = ""
        self.paragraph_list = []
        self.novel_content = ""
        self.writing_plan = ""
        self.temp_setting = ""
        self.writing_memory = ""
        # Text generated since the memory was last condensed by updateMemory().
        self.no_memory_paragraph = ""
        self.user_idea = ""
        self.user_requriments = ""
        self.history_states = []  # undo stack of snapshots taken by save_state()
        self.chapter_list = []  # (label, body) tuples derived from novel_content
        self.novel_outline_writer = MarkdownAgent(
            chatLLM=self.chatLLM,
            sys_prompt=system_prompt,
            user_prompt=novel_outline_writer_prompt,
            name="NovelOutlineWriter",
            temperature=0.98,
        )
        self.novel_writer = MarkdownAgent(
            chatLLM=self.chatLLM,
            sys_prompt=system_prompt,
            user_prompt=novel_writer_prompt,
            name="NovelWriter",
            temperature=0.81,
        )
        self.memory_maker = MarkdownAgent(
            chatLLM=self.chatLLM,
            sys_prompt=system_prompt,
            user_prompt=memory_maker_prompt,
            name="MemoryMaker",
            temperature=0.66,
        )

    def split_chapters(self, novel_content):
        """Split novel text into ``(chapter_label, chapter_body)`` tuples.

        The label is rebuilt as ``"第{num}章 {title}"``, or ``"第{num}章"``
        when the heading has no title. The body is the raw text up to the
        next heading. Text before the first heading belongs to no chapter
        and is not returned.
        """
        # With two capture groups, re.split yields
        # [preamble, num, title, body, num, title, body, ...]; always skipping
        # the preamble keeps the triples aligned even when it is non-empty.
        parts = self._CHAPTER_PATTERN.split(novel_content)
        chapter_tuples = []
        for chapter_num, chapter_title, chapter_content in zip(
            parts[1::3], parts[2::3], parts[3::3]
        ):
            # Drop trailing "**" left over from bold headings.
            title = chapter_title.strip().strip("*").strip()
            label = f"第{chapter_num}章 {title}" if title else f"第{chapter_num}章"
            chapter_tuples.append((label, chapter_content))
        return chapter_tuples

    def update_chapter_list(self):
        """Recompute ``chapter_list`` from the current ``novel_content``."""
        self.chapter_list = self.split_chapters(self.novel_content)

    def updateNovelContent(self):
        """Rebuild ``novel_content`` from ``paragraph_list`` and refresh chapters.

        Each paragraph is followed by a blank line. Returns the new text.
        """
        self.novel_content = "".join(f"{paragraph}\n\n" for paragraph in self.paragraph_list)
        self.update_chapter_list()
        return self.novel_content

    def genNovelOutline(self, user_idea=None):
        """Generate (and store) the outline from the user's idea.

        Args:
            user_idea: New idea text; when falsy, the stored idea is reused.
        """
        if user_idea:
            self.user_idea = user_idea
        resp = self.novel_outline_writer.invoke(
            inputs={"用户想法": self.user_idea},
            output_keys=["大纲"],
        )
        self.novel_outline = resp["大纲"]
        return self.novel_outline

    def _get_beginning_writer(self):
        """Return the agent used by genBeginning, creating it on first use.

        ``__init__`` never creates ``novel_beginning_writer``, so genBeginning
        used to fail with AttributeError unconditionally. An agent assigned to
        that attribute from outside is honoured; otherwise one is built from
        ``novel_beginning_writer_prompt``.

        NOTE(review): that prompt name follows the ``*_writer_prompt``
        convention of the other agents and is assumed to be exported by
        AIGN_Prompt; the temperature is likewise a guess. Confirm both.

        Raises:
            AttributeError: If no agent is set and the prompt is not defined.
        """
        writer = getattr(self, "novel_beginning_writer", None)
        if writer is None:
            prompt = globals().get("novel_beginning_writer_prompt")
            if prompt is None:
                raise AttributeError(
                    "AIGN has no 'novel_beginning_writer' and "
                    "'novel_beginning_writer_prompt' is not defined in AIGN_Prompt"
                )
            writer = MarkdownAgent(
                chatLLM=self.chatLLM,
                sys_prompt=system_prompt,
                user_prompt=prompt,
                name="NovelBeginningWriter",
                temperature=0.8,
            )
            self.novel_beginning_writer = writer
        return writer

    def genBeginning(self, user_requriments=None):
        """Generate the opening passage plus the initial plan and setting.

        Args:
            user_requriments: New writing requirements; when falsy, the stored
                ones are reused.

        Returns:
            The generated opening text, which is also appended to the novel.
        """
        if user_requriments:
            self.user_requriments = user_requriments
        resp = self._get_beginning_writer().invoke(
            inputs={
                "用户想法": self.user_idea,
                "小说大纲": self.novel_outline,
                "用户要求": self.user_requriments,
            },
            output_keys=["开头", "计划", "临时设定"],
        )
        # Snapshot only after a successful call, mirroring genNextParagraph,
        # so the beginning can be undone.
        self.save_state()
        beginning = resp["开头"]
        self.writing_plan = resp["计划"]
        self.temp_setting = resp["临时设定"]
        # NOTE(review): unlike genNextParagraph, the beginning is not appended
        # to no_memory_paragraph, so memory_maker never sees it. Confirm this
        # is intended.
        self.paragraph_list.append(beginning)
        self.updateNovelContent()
        return beginning

    def getLastParagraph(self, max_length=2000):
        """Return the most recent paragraphs that fit in about *max_length* chars.

        Paragraphs are taken newest-first while the accumulated text (each
        paragraph plus a trailing newline) stays under *max_length*, then
        returned in reading order, each followed by "\\n". If even the latest
        paragraph is too long, its last *max_length* characters are returned
        instead of nothing, so the writer always gets some preceding context.
        """
        pieces = []
        used = 0
        for paragraph in reversed(self.paragraph_list):
            if used + len(paragraph) >= max_length:
                break
            pieces.append(paragraph)
            used += len(paragraph) + 1  # +1 for the newline separator
        if not pieces and self.paragraph_list and max_length > 0:
            return self.paragraph_list[-1][-max_length:] + "\n"
        return "".join(f"{paragraph}\n" for paragraph in reversed(pieces))

    def recordNovel(self, path="novel_record.md"):
        """Write outline, text, memory, plan and setting to a Markdown file.

        Args:
            path: Destination file (overwritten), UTF-8 encoded.
        """
        record_content = (
            f"# 大纲\n\n{self.novel_outline}\n\n"
            "# 正文\n\n"
            f"{self.novel_content}"
            f"# 记忆\n\n{self.writing_memory}\n\n"
            f"# 计划\n\n{self.writing_plan}\n\n"
            f"# 临时设定\n\n{self.temp_setting}\n\n"
        )
        with open(path, "w", encoding="utf-8") as f:
            f.write(record_content)

    def updateMemory(self):
        """Condense pending text into ``writing_memory`` once it exceeds 2000 chars.

        On success the pending text (``no_memory_paragraph``) is cleared; if
        the agent call raises, the pending text is kept.
        """
        if len(self.no_memory_paragraph) > 2000:
            resp = self.memory_maker.invoke(
                inputs={
                    "前文记忆": self.writing_memory,
                    "正文内容": self.no_memory_paragraph,
                },
                output_keys=["新的记忆"],
            )
            self.writing_memory = resp["新的记忆"]
            self.no_memory_paragraph = ""

    def save_state(self):
        """Push a snapshot of the mutable novel state onto the undo stack.

        ``paragraph_list`` is copied. Storing the live list, as before, meant
        a later ``append`` also changed the snapshot, so ``undo`` could not
        remove the new paragraph.
        """
        self.history_states.append({
            "novel_outline": self.novel_outline,
            "paragraph_list": list(self.paragraph_list),
            "novel_content": self.novel_content,
            "writing_plan": self.writing_plan,
            "temp_setting": self.temp_setting,
            "writing_memory": self.writing_memory,
            "no_memory_paragraph": self.no_memory_paragraph,
        })

    def undo(self):
        """Restore the most recent snapshot.

        Returns:
            True if a snapshot was restored, False if the stack was empty.
        """
        if not self.history_states:
            return False
        previous_state = self.history_states.pop()
        self.novel_outline = previous_state["novel_outline"]
        self.paragraph_list = previous_state["paragraph_list"]
        self.novel_content = previous_state["novel_content"]
        self.writing_plan = previous_state["writing_plan"]
        self.temp_setting = previous_state["temp_setting"]
        self.writing_memory = previous_state["writing_memory"]
        # .get() tolerates snapshots created without this key.
        self.no_memory_paragraph = previous_state.get(
            "no_memory_paragraph", self.no_memory_paragraph
        )
        # The chapter index is derived data; rebuild it for the restored text.
        self.update_chapter_list()
        return True

    def genNextParagraph(self, user_requriments=None):
        """Generate the next paragraph and update plan, setting and memory.

        Args:
            user_requriments: New writing requirements; when falsy, the stored
                ones are reused.

        Returns:
            The generated paragraph. The full record file is also rewritten.
        """
        if user_requriments:
            self.user_requriments = user_requriments
        resp = self.novel_writer.invoke(
            inputs={
                "用户想法": self.user_idea,
                "大纲": self.novel_outline,
                "前文记忆": self.writing_memory,
                "临时设定": self.temp_setting,
                "计划": self.writing_plan,
                "用户要求": self.user_requriments,
                "上文内容": self.getLastParagraph(),
            },
            output_keys=["段落", "计划", "临时设定"],
        )
        # Snapshot only once the LLM call succeeded, so a failed attempt does
        # not leave a no-op entry on the undo stack.
        self.save_state()
        next_paragraph = resp["段落"]
        self.paragraph_list.append(next_paragraph)
        self.writing_plan = resp["计划"]
        self.temp_setting = resp["临时设定"]
        self.no_memory_paragraph += f"\n{next_paragraph}"
        # Refresh the text (and chapters) before the memory call, so a failure
        # there does not leave novel_content stale.
        self.updateNovelContent()
        self.updateMemory()
        self.recordNovel()
        return next_paragraph