# FAQ-workshop / blog_class.py
# (source-page residue: uploaded by yeelou — "add cai ren speaker", commit 05d04ce)
import time
import numpy as np
from api import PKL_FILE, generate_item, get_embedding, load_pkl, save_pkl
from get_blog import get_new_blog_content, get_old_blog_content
from prompts import (
Common_text,
Creative_text,
Full_text,
QA_chat_Prompt_matsu_template,
QA_chat_Prompt_other_template,
QA_Prompt_matsu_template,
QA_Prompt_other_template,
REWRITE_Prompt,
REWRITE_SYS_Prompt,
Short_text,
SYS_Prompt,
ken_style,
mana_style,
)
def max_cosine_similarity(v1, v2_list):
    """Return the maximum cosine similarity between *v1* and any row of *v2_list*.

    Args:
        v1: 1-D query vector (array-like).
        v2_list: candidate vectors, one per row (array-like; a single 1-D
            vector is promoted to one row).

    Returns:
        float: the best cosine similarity. Degenerate-input conventions match
        the original implementation: a zero *v1* yields 1.0 if any candidate
        is also zero, else 0.0; zero-norm candidate rows score -1.0 so they
        never beat a valid row. An empty candidate list yields 0.0.
    """
    v1 = np.asarray(v1, dtype=float)
    # atleast_2d lets callers pass a single vector (e.g. the `[0]` default
    # used by find_top_info) without crashing norm(..., axis=1).
    v2_list = np.atleast_2d(np.asarray(v2_list, dtype=float))
    if v2_list.size == 0:
        # No candidates at all: nothing to be similar to.
        return 0.0
    norm_v1 = np.linalg.norm(v1)
    norm_v2 = np.linalg.norm(v2_list, axis=1)
    if norm_v1 == 0:
        # Zero query vector: cosine similarity is undefined; keep the
        # original convention for this edge case.
        return 1.0 if np.any(norm_v2 == 0) else 0.0
    valid = norm_v2 > 0
    similarities = np.full(v2_list.shape[0], -1.0)
    if np.any(valid):
        similarities[valid] = np.dot(v2_list[valid], v1) / (
            norm_v1 * norm_v2[valid]
        )
    return float(np.max(similarities))
def save_feedback(value, liked):
    """Persist one feedback record as a timestamped markdown file.

    Args:
        value: the answer text the user reacted to.
        liked: True for a positive reaction, False for a negative one.
    """
    verdict = "liked" if liked else "disliked"
    md_text = "text:\n" + value + "\n" + verdict
    stamp = int(time.time() * 1000)
    with open(f"./resource/feedback_Text_{stamp}.md", "w", encoding="utf-8") as fh:
        fh.write(md_text)
def save_md_and_get_audio(user_prompt, answer_text):
    """Write the prompt/answer pair to markdown and return the paired audio path.

    The markdown and mp3 filenames share one millisecond timestamp so they
    can be matched up later; only the markdown file is written here.

    Returns:
        str: path of the (not yet created) mp3 file for this exchange.
    """
    stamp = int(time.time() * 1000)
    md_path = f"./resource/QA_Text_{stamp}.md"
    audio_filename = f"./resource/QA_Audio_{stamp}.mp3"
    content = user_prompt + "\n 応答: \n" + answer_text
    with open(md_path, "w", encoding="utf-8") as fh:
        fh.write(content)
    return audio_filename
class knowledge_class:
    """Blog-backed QA engine.

    Loads a pickled knowledge base of blog articles, retrieves the two most
    similar articles for a question via embedding cosine similarity, and
    builds speaker-specific prompts for answer generation.
    """

    def __init__(self):
        # Knowledge base keyed by article ref_name. Each entry is assumed to
        # carry "title"/"text"/"url"/"summary"/"audio"/"vector" (and "style"
        # for the matsu speaker) — TODO confirm against get_blog.py.
        self.knowledge_data = load_pkl(PKL_FILE)
        self.reference_dict = self.get_reference_dict()
        # Writing-style prompts for speakers outside the matsu/cai/ren set.
        self.other_speaker = {"ken": ken_style, "mana": mana_style}

    def get_reference_dict(self):
        """Build the display dict {ref_name: {original_text, summary, audio}}.

        Returns:
            dict: markdown-ready article text plus summary/audio per entry.

        Raises:
            ValueError: if a ref_name repeats. (Iterating a dict's own keys
                cannot repeat, so this is a sanity check only.)
        """
        reference_dict = {}
        for ref_name, ref_dict in self.knowledge_data.items():
            if ref_name in reference_dict:
                print(f"overlap ref_name: {ref_name}")
                # Fix: the original used a bare `raise` with no active
                # exception, which itself raises RuntimeError.
                raise ValueError(f"overlap ref_name: {ref_name}")
            title = ref_dict["title"]
            reference_dict[ref_name] = {
                "original_text": (
                    f"### {title}\n"
                    + ref_dict["text"]
                    + f"\n\n[URL]({ref_dict['url']})"
                ),
                "summary": ref_dict["summary"],
                "audio": ref_dict["audio"],
            }
        return reference_dict

    def get_new_knowledge(self):
        """Refresh the knowledge pickle with old and new blog content."""
        self.knowledge_data = get_old_blog_content(self.knowledge_data)
        self.knowledge_data = get_new_blog_content(self.knowledge_data)
        save_pkl(PKL_FILE, self.knowledge_data)
        # Reload from disk so in-memory state matches what was persisted.
        self.knowledge_data = load_pkl(PKL_FILE)
        self.reference_dict = self.get_reference_dict()
        print("PKLファイルの更新が完了しました。")

    def find_top_info(self, question_vector, speaker_flag):
        """Rank articles by similarity to the question and format the top two.

        Args:
            question_vector: embedding of the (possibly rewritten) question.
            speaker_flag: speaker id; "matsu" additionally includes each
                article's author-style notes in the context block.

        Returns:
            tuple[str, str]: (retrieve_text, retrieve_title) — the prompt
            context block and a numbered list of matched article keys.
        """
        scores = [
            (idx, max_cosine_similarity(question_vector, k_dict.get("vector", [0])))
            for idx, k_dict in self.knowledge_data.items()
        ]
        scores.sort(key=lambda pair: pair[1], reverse=True)
        retrieve_text = "\n"
        retrieve_title = ""
        # Dict keys are unique, so taking the two best-scoring entries is
        # equivalent to the original "skip duplicates until two" loop.
        for rank, (idx, sim) in enumerate(scores[:2], start=1):
            entry = self.knowledge_data[idx]
            retrieve_title += f"{rank}. {idx} \n"
            retrieve_text += f"情報 {rank} : \n"
            retrieve_text += f"- タイトル:{entry['title']} \n"
            retrieve_text += f"- コンテンツ:\n {entry['text']} \n"
            if speaker_flag in ["matsu"]:
                retrieve_text += f"- 作者の文体と思考パターン:\n {entry['style']} \n"
            retrieve_text += f"- 質問と類似度:{sim} \n"
            retrieve_text += "\n"
        return retrieve_text, retrieve_title

    def get_answer(
        self, question_text, creative_prompt, full_prompt, temperature, speaker_flag
    ):
        """Generate an answer for a standalone (history-free) question.

        Args:
            question_text: the user's question.
            creative_prompt: creativity prompt fragment (Creative/Common text).
            full_prompt: length prompt fragment (Full/Short text).
            temperature: sampling temperature passed to the LLM.
            speaker_flag: "matsu"/"cai"/"ren" use the matsu template; any
                other key selects a style from self.other_speaker (KeyError
                on an unknown speaker, as in the original).

        Returns:
            tuple: (answer_text, info_title, audio_filename).
        """
        # Retrieve the two most similar articles for the raw question.
        question_vector = get_embedding(question_text)[0]
        info_text, info_title = self.find_top_info(question_vector, speaker_flag)
        if speaker_flag in ["matsu", "cai", "ren"]:
            user_prompt = QA_Prompt_matsu_template.format(
                q_text=question_text,
                r_text=info_text,
                c_text=creative_prompt,
                s_text=full_prompt,
            )
        else:
            user_prompt = QA_Prompt_other_template.format(
                q_text=question_text,
                r_text=info_text,
                w_text=self.other_speaker[speaker_flag],
                c_text=creative_prompt,
                s_text=full_prompt,
            )
        print("user_prompt:", user_prompt)
        answer_text = generate_item(
            user_prompt, SYS_Prompt, model="gpt-4.1", temperature=temperature
        )
        audio_filename = save_md_and_get_audio(user_prompt, answer_text)
        return answer_text, info_title, audio_filename

    def get_chat_answer(self, chat_list, creative_flag, full_flag, speaker_flag):
        """Answer the latest user turn of a chat, rewriting it with history.

        Args:
            chat_list: list of {"role": ..., "content": ..., ...} turns.
            creative_flag: True selects the creative prompt and a higher
                temperature (1.2 vs 1.0).
            full_flag: True selects the long-form prompt, else short.
            speaker_flag: see get_answer.

        Returns:
            tuple: (answer_text, info_title, audio_filename), or
            (None, None, None) when the last turn is not a non-empty
            user turn.
        """
        print("all chat list:", chat_list)
        print("creative_flag:", creative_flag)
        print("full_flag:", full_flag)
        print("speaker_flag:", speaker_flag)
        creative_prompt = Creative_text if creative_flag else Common_text
        temperature = 1.2 if creative_flag else 1.0
        full_prompt = Full_text if full_flag else Short_text
        # Keep the last five turns, dropping assistant turns that carry
        # metadata (presumably tool/status messages — verify against the UI).
        chat_history = [
            turn
            for turn in chat_list
            if not (turn["role"] == "assistant" and turn.get("metadata"))
        ][-5:]
        if (
            not chat_history
            or chat_history[-1]["role"] != "user"
            or not chat_history[-1]["content"]
        ):
            return None, None, None
        if len(chat_history) == 1:
            # No usable history: fall back to single-question answering.
            return self.get_answer(
                chat_history[-1]["content"],
                creative_prompt,
                full_prompt,
                temperature,
                speaker_flag,
            )
        chat_history_str = "\n".join(
            [f"{msg['role']}: {msg['content']}" for msg in chat_history[:-1]]
        )
        print("chat_history_str:", chat_history_str)
        new_query = chat_history[-1]["content"]
        print("new_query:", new_query)
        # Rewrite the follow-up question into a standalone question so that
        # embedding retrieval works without the conversational context.
        rw_prompt = REWRITE_Prompt.format(
            chat_history=chat_history_str, question=new_query
        )
        rewrite_question = generate_item(rw_prompt, REWRITE_SYS_Prompt, model="gpt-4.1")
        print("rewrite_question:", rewrite_question)
        question_vector = get_embedding(rewrite_question)[0]
        info_text, info_title = self.find_top_info(question_vector, speaker_flag)
        if speaker_flag in ["matsu", "cai", "ren"]:
            user_prompt = QA_chat_Prompt_matsu_template.format(
                h_text=chat_history_str,
                q_text=rewrite_question,
                r_text=info_text,
                c_text=creative_prompt,
                s_text=full_prompt,
            )
        else:
            user_prompt = QA_chat_Prompt_other_template.format(
                h_text=chat_history_str,
                q_text=rewrite_question,
                r_text=info_text,
                w_text=self.other_speaker[speaker_flag],
                c_text=creative_prompt,
                s_text=full_prompt,
            )
        print("user_prompt:", user_prompt)
        answer_text = generate_item(
            user_prompt,
            SYS_Prompt,
            model="gpt-4.1",
            temperature=temperature,
        )
        audio_filename = save_md_and_get_audio(user_prompt, answer_text)
        return answer_text, info_title, audio_filename
# Example usage (run manually to refresh the knowledge pickle):
# kc_class = knowledge_class()
# print(kc_class.get_new_knowledge())