Spaces:
Build error
Build error
| import json | |
def data2reference(top_k_items, output_n=3):
    """Build the ``#Reference:`` section of the prompt from retrieved items.

    Items are visited in order. An item whose ``"keyword"`` has already been
    emitted is skipped, and iteration stops once ``output_n`` distinct
    keywords have been emitted. Each emitted item is serialized as one JSON
    object (non-ASCII characters are kept verbatim) and entries are
    separated by a blank line.

    Args:
        top_k_items: Iterable of dicts providing the keys ``"keyword"``,
            ``"name_in_cultivation"`` and ``"description_in_cultivation"``.
        output_n: Maximum number of distinct keywords to include. A value
            of zero or less yields only the header.

    Returns:
        The header followed by the JSON entries, with surrounding
        whitespace stripped.

    Raises:
        KeyError: If a required key is missing from an item that is
            processed.
    """
    header = "#Reference:\n"

    # The limit is checked after each append, so without this guard a
    # non-positive limit would still let the first item through.
    if output_n <= 0:
        return header.strip()

    seen_keywords = set()
    entries = []
    for item in top_k_items:
        item_in_life = item["keyword"]
        if item_in_life in seen_keywords:
            continue

        entry = {
            "name_in_life": item_in_life,
            "name_in_cultivation": item["name_in_cultivation"],
            "description_in_cultivation": item["description_in_cultivation"],
        }
        entries.append(json.dumps(entry, ensure_ascii=False))
        seen_keywords.add(item_in_life)

        if len(seen_keywords) >= output_n:
            break

    return (header + "\n\n".join(entries)).strip()
def data2prompt(query_item, top_k_items):
    """Assemble the full LLM prompt for rewriting an item into its cultivation-world form.

    The prompt is the concatenation of four parts, in order: the reference
    section built by ``data2reference`` (limited to 3 entries), a task
    instruction, an ``# Input:`` section describing ``query_item``, and a
    chain-of-thought instruction listing the JSON fields to output.

    Args:
        query_item: Mapping describing the input item. When it has a
            ``"keyword"`` entry, that (and ``"description"`` if present) is
            written out field by field; otherwise the whole mapping is
            dumped as JSON.
        top_k_items: Retrieved reference items, passed to ``data2reference``.

    Returns:
        The prompt as a single string.
    """
    # Build the reference section first, as the original did.
    pieces = [data2reference(top_k_items, 3)]
    pieces.append("\n请参考Reference中的物品描述,将Input中的输入物品,联系改写成修仙世界中的对应物品\n")
    pieces.append("# Input:\n")

    # NOTE(review): the source this was recovered from had lost its
    # indentation; the `else` is assumed to pair with the "keyword" check
    # (dump the raw item only when no keyword is available) -- confirm.
    if "keyword" not in query_item:
        # No structured name available: dump the item as-is.
        pieces.append(json.dumps(query_item, ensure_ascii=False) + "\n")
    else:
        pieces.append(f"input_name:{query_item['keyword']}\n")
        if "description" in query_item:
            pieces.append(f"description_in_life:{query_item['description']}\n")

    # Field-by-field output specification; the text is part of the prompt
    # and must stay exactly as written.
    pieces.append("""Let's think it step by step,以json形式输出逐个字段。包含以下字段
- name_in_life: 进一步明确要生成描述的物品名称
- name_in_cultivation_1: 尝试编写物品在修仙界对应的名称
- description_in_cultivation_1: 尝试编写物品在修仙界对应的描述
- echo_1: "我将分析description_in_cultivation_1与Reference中的差异,分析description_in_cultivation_1是否已经足够生动"
- critique: 相比于Reference中的描述,分析description_in_cultivation_1在哪些方面有所欠缺
- echo_2: "根据input_name和description_in_cultivation_1,我将分析从物体的哪些属性,可以进一步加强、夸张和修改描述"
- analysis: 分析从物体的哪些属性,可以进一步加强、夸张和修改描述
- echo_3: "我将尝试3次,从不同角度加强description_in_cultivation_1的描述"
- candidate_descriptions: 从不同角度,输出3次不同的加强后的描述
- analysis_candidates: 分析各个candidates有什么优点
- echo_4: "根据analysis_candidates,我将merge出一个最终的描述"
- final_enhanced_description: 通过各个candidates的优点, merge出一个最终的描述
- echo_5: "我将分析根据final_description,是否简易将物品名称替换为新的名词"
- name_fit_analysis: 分析item_name是否还匹配final_description的描述,是否需要给input_name起一个更响亮的名字
- new_name: 如果需要,给input_name起一个更响亮的名字, 如果不需要,则仍然输出name_in_cultivation_1
""")
    return "".join(pieces)
# Try the package-qualified import first and fall back to the flat module
# name. Only an import failure triggers the fallback, so unrelated errors
# raised while loading the module are not masked.
try:
    from src.ZhipuClient import ZhipuClient
except ImportError:
    from ZhipuClient import ZhipuClient

# Shared client instance; created lazily on first use by
# generate_cultivation_with_rag.
zhipu_client = None
| import json | |
def markdown_to_json(markdown_str):
    """Parse JSON that may be wrapped in a Markdown code fence.

    Surrounding whitespace is ignored. If the text starts with a ``` fence,
    the fence, an optional ``json`` language tag and (when present) the
    closing fence are removed before parsing. A missing closing fence no
    longer eats the last characters of the payload.

    Args:
        markdown_str: Raw text, e.g. a model response.

    Returns:
        The decoded JSON value (normally a dict).

    Raises:
        json.JSONDecodeError: If the remaining text is not valid JSON.
    """
    fence = "```"
    text = markdown_str.strip()

    if text.startswith(fence):
        text = text[len(fence):]
        # Optional language tag directly after the opening fence.
        if text[:4].lower() == "json":
            text = text[4:]
        # Only drop a closing fence that is actually there.
        if text.endswith(fence):
            text = text[:-len(fence)]
        text = text.strip()

    return json.loads(text)
| import re | |
def forced_extract(input_str, keywords):
    """Best-effort extraction of ``"key": "value"`` string pairs from text.

    Intended for text that looks like JSON but cannot be parsed as a whole.
    For each keyword the first occurrence of ``"<keyword>": "<value>"`` is
    located with a regular expression.

    Args:
        input_str: Text to search.
        keywords: Iterable of key names to look for.

    Returns:
        A dict with one entry per keyword; keywords that are not found map
        to the empty string. Found values have their JSON escapes decoded
        when possible, otherwise the raw matched text is returned.
    """
    result = {key: "" for key in keywords}
    for key in result:
        # Raw string so regex escapes are not interpreted by Python; the
        # key is escaped so characters such as "." match literally. The
        # value group accepts backslash escapes, so an escaped quote does
        # not end the value early, and it may span several lines.
        pattern = rf'"{re.escape(str(key))}"\s*:\s*"((?:[^"\\]|\\.)*)"'
        match = re.search(pattern, input_str, re.DOTALL)
        if not match:
            continue

        raw_value = match.group(1)
        try:
            # strict=False tolerates literal control characters (e.g. raw
            # newlines) inside the string.
            result[key] = json.loads(f'"{raw_value}"', strict=False)
        except ValueError:
            # Not a well-formed JSON string body: keep the text as matched.
            result[key] = raw_value
    return result
def generate_cultivation_with_rag(query_item, search_result):
    """Ask the LLM to rewrite an item into its cultivation-world counterpart.

    Builds a prompt from ``query_item`` and the retrieved ``search_result``
    via ``data2prompt``, sends it through the shared ``ZhipuClient``
    (created on first use) and converts the reply to a dict.

    The reply is parsed with ``markdown_to_json``. If that fails, or yields
    something other than a dict, the key fields are recovered with
    ``forced_extract`` instead.

    Args:
        query_item: Mapping describing the input item (see ``data2prompt``).
        search_result: Retrieved reference items (see ``data2reference``).

    Returns:
        A dict that always contains ``"new_name"`` and
        ``"final_enhanced_description"``. A missing or empty ``"new_name"``
        falls back to ``"name_in_cultivation_1"`` (or ``""``); a missing or
        empty ``"final_enhanced_description"`` falls back to
        ``"description_in_cultivation_1"`` (or the new name).
    """
    global zhipu_client
    if zhipu_client is None:
        zhipu_client = ZhipuClient()

    prompt = data2prompt(query_item, search_result)
    response = zhipu_client.prompt2response(prompt)

    # Parsing is best-effort: any ordinary failure falls through to the
    # regex-based recovery below. KeyboardInterrupt/SystemExit propagate.
    try:
        json_response = markdown_to_json(response)
    except Exception:
        json_response = None

    # Valid JSON that is not an object (list, string, number) cannot carry
    # the expected fields, so treat it like a parse failure.
    if not isinstance(json_response, dict):
        keyword_list = [
            "name_in_life",
            "name_in_cultivation_1",
            "description_in_cultivation_1",
            "final_enhanced_description",
            "new_name",
        ]
        json_response = forced_extract(response, keyword_list)

    # A JSON null is treated the same as a missing/empty field.
    if json_response.get("new_name") in (None, ""):
        json_response["new_name"] = json_response.get("name_in_cultivation_1", "")

    if json_response.get("final_enhanced_description") in (None, ""):
        json_response["final_enhanced_description"] = json_response.get(
            "description_in_cultivation_1", json_response["new_name"]
        )

    return json_response
if __name__ == '__main__':
    # Manual end-to-end check: caption one image, look up related entries,
    # pick the major object and generate its cultivation-world rewrite.

    # Each project module is tried under the `src` package first and then
    # as a flat module; only an import failure triggers the fallback.
    try:
        from src.Database import Database
    except ImportError:
        from Database import Database

    db = Database()

    try:
        from src.Captioner import Captioner
    except ImportError:
        from Captioner import Captioner

    import os

    # Route HTTP(S) traffic through a local proxy; set before the
    # captioner is constructed.
    os.environ['HTTP_PROXY'] = 'http://localhost:8234'
    os.environ['HTTPS_PROXY'] = 'http://localhost:8234'

    captioner = Captioner()

    test_image = "temp_images/3or47vg0.jpg"

    caption_response = captioner.caption(test_image)
    search_result = db.search_with_image_name(test_image)

    # Collect the translated words in first-seen order, without duplicates.
    seen = set()
    keywords = []
    for res in search_result:
        word = res['translated_word']
        if word in seen:
            continue
        seen.add(word)
        keywords.append(word)

    from get_major_object import get_major_object, verify_keyword_in_base

    json_response = get_major_object(caption_response, keywords)

    print(json_response)
    print()

    in_base_data, alt_data = verify_keyword_in_base(json_response, db)

    # Only generate a rewrite when an alternative item was returned.
    if alt_data is not None:
        result = generate_cultivation_with_rag(alt_data, search_result)
        print(result)