Spaces:
Sleeping
Sleeping
| import time | |
| import numpy as np | |
| from api import PKL_FILE, generate_item, get_embedding, load_pkl, save_pkl | |
| from get_blog import get_new_blog_content, get_old_blog_content | |
| from prompts import ( | |
| Common_text, | |
| Creative_text, | |
| Full_text, | |
| QA_chat_Prompt_template, | |
| QA_Prompt_template, | |
| REWRITE_Prompt, | |
| REWRITE_SYS_Prompt, | |
| Short_text, | |
| SYS_Prompt, | |
| ) | |
def max_cosine_similarity(v1, v2_list):
    """Return the maximum cosine similarity between ``v1`` and each row of ``v2_list``.

    Zero-norm rows have no defined cosine similarity; they are assigned -1.0
    (the minimum possible value) so they can never win over a valid row.
    A zero-norm ``v1`` is treated as identical to another zero vector:
    it returns 1.0 when any row has zero norm and 0.0 otherwise.

    Args:
        v1: 1-D query vector (anything ``np.array`` accepts).
        v2_list: 2-D candidate matrix — one vector per row.

    Returns:
        float: maximum similarity in [-1.0, 1.0]; -1.0 when ``v2_list`` is
        empty (previously this raised inside ``np.linalg.norm``/``np.max``).
    """
    v1 = np.array(v1)
    v2_list = np.array(v2_list)
    # Robustness: an empty candidate list would make norm(axis=1)/np.max raise.
    if v2_list.size == 0:
        return -1.0
    norm_v1 = np.linalg.norm(v1)
    norm_v2 = np.linalg.norm(v2_list, axis=1)
    if norm_v1 == 0:
        # Degenerate query: only "similar" (1.0) to another zero vector.
        return 1.0 if np.any(norm_v2 == 0) else 0.0
    valid_indices = norm_v2 > 0
    similarities = np.full(v2_list.shape[0], -1.0)
    if np.any(valid_indices):
        # Vectorized cosine: (row . v1) / (|row| * |v1|) for valid rows only.
        similarities[valid_indices] = np.dot(v2_list[valid_indices], v1) / (
            norm_v1 * norm_v2[valid_indices]
        )
    return np.max(similarities)
def save_feedback(value, liked):
    """Persist a user's verdict on generated text as a timestamped markdown file.

    Args:
        value: The generated text being rated.
        liked: Truthy for a positive rating, falsy for a negative one.
    """
    verdict = "liked" if liked else "disliked"
    feedback_body = "text:\n" + value + "\n" + verdict
    stamp_ms = int(time.time() * 1000)
    out_path = f"./resource/feedback_Text_{stamp_ms}.md"
    with open(out_path, "w", encoding="utf-8") as fh:
        fh.write(feedback_body)
class knowledge_class:
    """Retrieval-augmented QA over blog content stored in a pickled knowledge base.

    Loads entries from ``PKL_FILE``, retrieves the entries most similar to a
    question via embedding cosine similarity, and generates answers with an
    LLM. Each answer exchange is archived as a markdown file under
    ``./resource/``.
    """

    def __init__(self):
        # Load the pickled knowledge base and build the display-ready reference map.
        self.knowledge_data = load_pkl(PKL_FILE)
        self.reference_dict = self.get_reference_dict()

    def get_reference_dict(self):
        """Build a display map ``{ref_name: {original_text, summary, audio}}``.

        Returns:
            dict: per-entry markdown text (title + body + URL link), summary,
            and audio path.

        Raises:
            ValueError: if the same ``ref_name`` appears twice. (The original
            used a bare ``raise`` here, which itself fails with
            ``RuntimeError: No active exception to re-raise``.)
        """
        reference_dict = {}
        for ref_name, ref_dict in self.knowledge_data.items():
            if ref_name in reference_dict:
                print(f"overlap ref_name: {ref_name}")
                raise ValueError(f"overlap ref_name: {ref_name}")
            title = ref_dict["title"]
            reference_dict[ref_name] = {
                "original_text": (
                    f"### {title}\n"
                    + ref_dict["text"]
                    + f"\n\n[URL]({ref_dict['url']})"
                ),
                "summary": ref_dict["summary"],
                "audio": ref_dict["audio"],
            }
        return reference_dict

    def get_new_knowledge(self):
        """Refresh the knowledge base from old and new blog content, then persist it."""
        self.knowledge_data = get_old_blog_content(self.knowledge_data)
        self.knowledge_data = get_new_blog_content(self.knowledge_data)
        save_pkl(PKL_FILE, self.knowledge_data)
        # Reload from disk so in-memory state matches exactly what was saved.
        self.knowledge_data = load_pkl(PKL_FILE)
        self.reference_dict = self.get_reference_dict()
        print("PKLファイルの更新が完了しました。")

    def find_top_info(self, question_vector):
        """Return formatted text and titles for the 2 entries most similar to the query.

        Args:
            question_vector: embedding vector of the (rewritten) question.

        Returns:
            tuple[str, str]: (retrieve_text, retrieve_title) — a prompt-ready
            block of the top-2 entries, and their numbered title list.
        """
        results = []
        for idx, k_dict in self.knowledge_data.items():
            # NOTE(review): the 1-D fallback [0] would break max_cosine_similarity's
            # axis=1 norm — presumably every entry stores a list of vectors; confirm.
            idx_vector = k_dict.get("vector", [0])
            cos_sim = max_cosine_similarity(question_vector, idx_vector)
            results.append((idx, cos_sim))
        results_sorted = sorted(results, key=lambda x: x[1], reverse=True)
        top2 = []
        retrieve_text = "\n"
        retrieve_title = ""
        for res_sort in results_sorted:
            if res_sort[0] not in top2:
                top2.append(res_sort[0])
                entry = self.knowledge_data[res_sort[0]]
                retrieve_text += f"情報 {len(top2)} : \n"
                retrieve_text += f"- タイトル:{entry['title']} \n"
                retrieve_title += f"{len(top2)}. {res_sort[0]} \n"
                retrieve_text += f"- コンテンツ:\n {entry['text']} \n"
                retrieve_text += f"- 作者の文体と思考パターン:\n {entry['style']} \n"
                retrieve_text += f"- 質問と類似度:{res_sort[1]} \n"
                retrieve_text += "\n"
            if len(top2) > 1:
                break
        return retrieve_text, retrieve_title

    def _save_qa_record(self, user_prompt, answer_text):
        """Archive prompt + answer as markdown; return the paired audio path.

        NOTE(review): the .mp3 file is not written here — presumably it is
        generated downstream from the returned path; confirm against callers.
        """
        md_text = user_prompt + "\n 応答: \n" + answer_text
        timestamp = int(time.time() * 1000)
        md_filename = f"./resource/QA_Text_{timestamp}.md"
        audio_filename = f"./resource/QA_Audio_{timestamp}.mp3"
        with open(md_filename, "w", encoding="utf-8") as file:
            file.write(md_text)
        return audio_filename

    def get_answer(self, question_text, creative_prompt, full_prompt, temperature):
        """Answer a standalone question using the most similar knowledge entries.

        Args:
            question_text: the user's question.
            creative_prompt: style-control prompt fragment.
            full_prompt: length-control prompt fragment.
            temperature: sampling temperature passed to the LLM.

        Returns:
            tuple[str, str, str]: (answer_text, info_title, audio_filename).
        """
        # Retrieve the entries most similar to the question embedding.
        question_vector = get_embedding(question_text)[0]
        info_text, info_title = self.find_top_info(question_vector)
        user_prompt = QA_Prompt_template.format(
            q_text=question_text,
            r_text=info_text,
            c_text=creative_prompt,
            s_text=full_prompt,
        )
        answer_text = generate_item(
            user_prompt, SYS_Prompt, model="gpt-4.1", temperature=temperature
        )
        audio_filename = self._save_qa_record(user_prompt, answer_text)
        return answer_text, info_title, audio_filename

    def get_chat_answer(self, chat_list, creative_flag, full_flag):
        """Answer the latest user turn of a chat.

        Multi-turn histories are first condensed into a standalone question
        via a rewrite prompt; single-turn chats go straight to
        :meth:`get_answer`.

        Args:
            chat_list: list of {"role", "content", "metadata"} turns.
            creative_flag: truthy → creative prompt and higher temperature.
            full_flag: truthy → full-length answer prompt.

        Returns:
            tuple: (answer_text, info_title, audio_filename), or
            (None, None, None) when there is no answerable user turn.
        """
        print("all chat list:", chat_list)
        print("creative_flag:", creative_flag)
        print("full_flag:", full_flag)
        # Flag-driven prompt / temperature selection.
        creative_prompt = Common_text
        temperature = 1.0
        if creative_flag:
            creative_prompt = Creative_text
            temperature = 1.2
        full_prompt = Full_text if full_flag else Short_text
        # Drop assistant turns that carry metadata (UI-only messages); keep the rest.
        chat_history = [
            turn
            for turn in chat_list
            if not (turn["role"] == "assistant" and turn["metadata"])
        ]
        # Keep only the 5 most recent turns.
        chat_history = chat_history[-5:]
        # Guard: an empty filtered history previously raised IndexError below.
        if not chat_history:
            return None, None, None
        if chat_history[-1]["role"] != "user" or (not chat_history[-1]["content"]):
            return None, None, None
        if len(chat_history) == 1:
            # No prior context: plain single-question path.
            return self.get_answer(
                chat_history[-1]["content"],
                creative_prompt,
                full_prompt,
                temperature,
            )
        # Serialize prior turns for the rewrite prompt.
        chat_history_str = "\n".join(
            f"{msg['role']}: {msg['content']}" for msg in chat_history[:-1]
        )
        print("chat_history_str:", chat_history_str)
        new_query = chat_history[-1]["content"]
        print("new_query:", new_query)
        # Rewrite the follow-up into a standalone question before retrieval.
        rw_prompt = REWRITE_Prompt.format(
            chat_history=chat_history_str, question=new_query
        )
        rewrite_question = generate_item(rw_prompt, REWRITE_SYS_Prompt, model="gpt-4.1")
        print("rewrite_question:", rewrite_question)
        # Retrieve knowledge relevant to the rewritten question.
        question_vector = get_embedding(rewrite_question)[0]
        info_text, info_title = self.find_top_info(question_vector)
        user_prompt = QA_chat_Prompt_template.format(
            h_text=chat_history_str,
            q_text=rewrite_question,
            r_text=info_text,
            c_text=creative_prompt,
            s_text=full_prompt,
        )
        answer_text = generate_item(
            user_prompt,
            SYS_Prompt,
            model="gpt-4.1",
            temperature=temperature,
        )
        audio_filename = self._save_qa_record(user_prompt, answer_text)
        return answer_text, info_title, audio_filename
| # kc_class = knowledge_class() | |
| # print(kc_class.get_new_knowledge()) | |