Spaces:
Sleeping
Sleeping
| import time | |
| import numpy as np | |
| from api import PKL_FILE, generate_item, get_embedding, load_pkl, save_pkl | |
| from get_blog import get_new_blog_content, get_old_blog_content | |
| from prompts import ( | |
| Common_text, | |
| Creative_text, | |
| Full_text, | |
| QA_chat_Prompt_matsu_template, | |
| QA_chat_Prompt_other_template, | |
| QA_Prompt_matsu_template, | |
| QA_Prompt_other_template, | |
| REWRITE_Prompt, | |
| REWRITE_SYS_Prompt, | |
| Short_text, | |
| SYS_Prompt, | |
| ken_style, | |
| mana_style, | |
| ) | |
def max_cosine_similarity(v1, v2_list):
    """Return the maximum cosine similarity between *v1* and each row of *v2_list*.

    Args:
        v1: 1-D query vector (any sequence convertible to a numpy array).
        v2_list: 2-D sequence of candidate vectors, one vector per row.

    Returns:
        float: the highest cosine similarity found. Zero-norm candidate rows
        are scored -1.0 so they never win unless every candidate is
        degenerate. If *v1* itself has zero norm, returns 1.0 when any
        candidate also has zero norm (two "empty" vectors are treated as
        identical), otherwise 0.0. An empty *v2_list* yields -1.0.
    """
    v1 = np.array(v1)
    v2_list = np.array(v2_list)
    # Guard: nothing to compare against. The original code crashed here
    # (np.max over an empty array raises ValueError).
    if v2_list.size == 0:
        return -1.0
    norm_v1 = np.linalg.norm(v1)
    norm_v2 = np.linalg.norm(v2_list, axis=1)
    if norm_v1 == 0:
        return 1.0 if np.any(norm_v2 == 0) else 0.0
    valid_indices = norm_v2 > 0
    # Sentinel -1.0 sits below any real cosine value, so zero-norm rows lose.
    similarities = np.full(v2_list.shape[0], -1.0)
    if np.any(valid_indices):
        similarities[valid_indices] = np.dot(v2_list[valid_indices], v1) / (
            norm_v1 * norm_v2[valid_indices]
        )
    return np.max(similarities)
def save_feedback(value, liked):
    """Persist a user's like/dislike feedback on *value* as a timestamped markdown file."""
    verdict = "liked" if liked else "disliked"
    feedback_body = f"text:\n{value}\n{verdict}"
    stamp = int(time.time() * 1000)
    with open(f"./resource/feedback_Text_{stamp}.md", "w", encoding="utf-8") as fh:
        fh.write(feedback_body)
def save_md_and_get_audio(user_prompt, answer_text):
    """Write prompt + answer to a timestamped markdown file and return the matching mp3 path."""
    stamp = int(time.time() * 1000)
    md_path = f"./resource/QA_Text_{stamp}.md"
    audio_path = f"./resource/QA_Audio_{stamp}.mp3"
    with open(md_path, "w", encoding="utf-8") as fh:
        fh.write(user_prompt + "\n 応答: \n" + answer_text)
    return audio_path
class knowledge_class:
    """Blog-knowledge retrieval and answer generation over the pickled knowledge base.

    Loads the knowledge dict from PKL_FILE (one entry per blog post, each
    holding at least title / text / url / summary / audio / vector fields),
    builds a display-oriented reference dict, and answers single-turn or
    chat questions by embedding-based retrieval feeding prompt templates.
    """

    def __init__(self):
        # Raw knowledge base: {ref_name: {"title": ..., "text": ..., "vector": [...], ...}}
        self.knowledge_data = load_pkl(PKL_FILE)
        self.reference_dict = self.get_reference_dict()
        # Writing-style prompts for the non-"matsu" speakers.
        self.other_speaker = {"ken": ken_style, "mana": mana_style}

    def get_reference_dict(self):
        """Build {ref_name: {original_text, summary, audio}} for display.

        Returns:
            dict: markdown-formatted original text plus summary and audio path
            for every post in the knowledge base.

        Raises:
            ValueError: if the same ref_name appears twice. (Cannot occur while
                knowledge_data is a plain dict — keys are unique — but kept as
                a data-integrity guard.)
        """
        reference_dict = {}
        for ref_name, ref_dict in self.knowledge_data.items():
            if ref_name in reference_dict:
                print(f"overlap ref_name: {ref_name}")
                # BUGFIX: was a bare `raise` outside any except block, which
                # itself fails with "No active exception to re-raise".
                raise ValueError(f"overlap ref_name: {ref_name}")
            title = ref_dict["title"]
            reference_dict[ref_name] = {
                "original_text": (
                    f"### {title}\n"
                    + ref_dict["text"]
                    + f"\n\n[URL]({ref_dict['url']})"
                ),
                "summary": ref_dict["summary"],
                "audio": ref_dict["audio"],
            }
        return reference_dict

    def get_new_knowledge(self):
        """Refresh the knowledge base from the blog, persist it, and rebuild caches."""
        self.knowledge_data = get_old_blog_content(self.knowledge_data)
        self.knowledge_data = get_new_blog_content(self.knowledge_data)
        save_pkl(PKL_FILE, self.knowledge_data)
        # Re-load from disk so in-memory state matches exactly what was persisted.
        self.knowledge_data = load_pkl(PKL_FILE)
        self.reference_dict = self.get_reference_dict()
        print("PKLファイルの更新が完了しました。")

    def find_top_info(self, question_vector, speaker_flag):
        """Return (retrieve_text, retrieve_title) built from the two most similar posts.

        Similarity per post is the max cosine similarity between the question
        embedding and the post's stored chunk vectors. For speaker "matsu" the
        retrieved text additionally includes the author's style notes.
        """
        results = []
        for idx, k_dict in self.knowledge_data.items():
            # NOTE(review): the [0] fallback is 1-D while max_cosine_similarity
            # expects a 2-D candidate list, so a post missing "vector" would
            # raise — confirm vectors are always present.
            idx_vector = k_dict.get("vector", [0])
            cos_sim = max_cosine_similarity(question_vector, idx_vector)
            results.append((idx, cos_sim))
        results_sorted = sorted(results, key=lambda x: x[1], reverse=True)
        top2 = []
        retrieve_text = "\n"
        retrieve_title = ""
        for res_sort in results_sorted:
            if res_sort[0] not in top2:
                top2.append(res_sort[0])
                retrieve_text += f"情報 {len(top2)} : \n"
                retrieve_text += (
                    f"- タイトル:{self.knowledge_data[res_sort[0]]['title']} \n"
                )
                retrieve_title += f"{len(top2)}. {res_sort[0]} \n"
                retrieve_text += (
                    f"- コンテンツ:\n {self.knowledge_data[res_sort[0]]['text']} \n"
                )
                if speaker_flag in ["matsu"]:
                    retrieve_text += f"- 作者の文体と思考パターン:\n {self.knowledge_data[res_sort[0]]['style']} \n"
                retrieve_text += f"- 質問と類似度:{res_sort[1]} \n"
                retrieve_text += "\n"
            if len(top2) > 1:
                break
        return retrieve_text, retrieve_title

    def get_answer(
        self, question_text, creative_prompt, full_prompt, temperature, speaker_flag
    ):
        """Answer a single-turn question.

        Embeds the question, retrieves the top posts, fills the speaker-specific
        QA template, generates the answer, and saves a markdown record.

        Returns:
            tuple: (answer_text, info_title, audio_filename).
        """
        question_vector = get_embedding(question_text)[0]
        info_text, info_title = self.find_top_info(question_vector, speaker_flag)
        if speaker_flag in ["matsu", "cai", "ren"]:
            user_prompt = QA_Prompt_matsu_template.format(
                q_text=question_text,
                r_text=info_text,
                c_text=creative_prompt,
                s_text=full_prompt,
            )
        else:
            user_prompt = QA_chat_Prompt_other_template.format(
                q_text=question_text,
                r_text=info_text,
                w_text=self.other_speaker[speaker_flag],
                c_text=creative_prompt,
                s_text=full_prompt,
            ) if False else QA_Prompt_other_template.format(
                q_text=question_text,
                r_text=info_text,
                w_text=self.other_speaker[speaker_flag],
                c_text=creative_prompt,
                s_text=full_prompt,
            )
        print("user_prompt:", user_prompt)
        answer_text = generate_item(
            user_prompt, SYS_Prompt, model="gpt-4.1", temperature=temperature
        )
        audio_filename = save_md_and_get_audio(user_prompt, answer_text)
        return answer_text, info_title, audio_filename

    def get_chat_answer(self, chat_list, creative_flag, full_flag, speaker_flag):
        """Answer the latest user turn of a chat.

        Keeps the last 5 non-metadata turns; single-turn chats delegate to
        get_answer, otherwise the question is rewritten with the history
        before retrieval and generation.

        Returns:
            tuple: (answer_text, info_title, audio_filename), or
            (None, None, None) when there is no valid trailing user message.
        """
        print("all chat list:", chat_list)
        print("creative_flag:", creative_flag)
        print("full_flag:", full_flag)
        print("speaker_flag:", speaker_flag)
        creative_prompt = Creative_text if creative_flag else Common_text
        temperature = 1.2 if creative_flag else 1.0
        full_prompt = Full_text if full_flag else Short_text
        # Drop assistant turns that carry metadata (UI artifacts), keep last 5.
        chat_history = [
            turn
            for turn in chat_list
            if not (turn["role"] == "assistant" and turn.get("metadata"))
        ][-5:]
        if (
            not chat_history
            or chat_history[-1]["role"] != "user"
            or not chat_history[-1]["content"]
        ):
            return None, None, None
        if len(chat_history) == 1:
            return self.get_answer(
                chat_history[-1]["content"],
                creative_prompt,
                full_prompt,
                temperature,
                speaker_flag,
            )
        chat_history_str = "\n".join(
            [f"{msg['role']}: {msg['content']}" for msg in chat_history[:-1]]
        )
        print("chat_history_str:", chat_history_str)
        new_query = chat_history[-1]["content"]
        print("new_query:", new_query)
        # Rewrite the question so it is self-contained given the history.
        rw_prompt = REWRITE_Prompt.format(
            chat_history=chat_history_str, question=new_query
        )
        rewrite_question = generate_item(rw_prompt, REWRITE_SYS_Prompt, model="gpt-4.1")
        print("rewrite_question:", rewrite_question)
        question_vector = get_embedding(rewrite_question)[0]
        info_text, info_title = self.find_top_info(question_vector, speaker_flag)
        if speaker_flag in ["matsu", "cai", "ren"]:
            user_prompt = QA_chat_Prompt_matsu_template.format(
                h_text=chat_history_str,
                q_text=rewrite_question,
                r_text=info_text,
                c_text=creative_prompt,
                s_text=full_prompt,
            )
        else:
            user_prompt = QA_chat_Prompt_other_template.format(
                h_text=chat_history_str,
                q_text=rewrite_question,
                r_text=info_text,
                w_text=self.other_speaker[speaker_flag],
                c_text=creative_prompt,
                s_text=full_prompt,
            )
        print("user_prompt:", user_prompt)
        answer_text = generate_item(
            user_prompt,
            SYS_Prompt,
            model="gpt-4.1",
            temperature=temperature,
        )
        audio_filename = save_md_and_get_audio(user_prompt, answer_text)
        return answer_text, info_title, audio_filename
| # kc_class = knowledge_class() | |
| # print(kc_class.get_new_knowledge()) | |