import random import json from difflib import SequenceMatcher import pandas as pd def generate_prompt_with_examples(text, label, n, start_index=500, end_index=1000): """ 生成带有n个样例的提示。 Args: text: 包含文本描述的 Series。 label: 包含三元组列表的 Series。 n: 要提取的样例数量。 start_index: 样例的起始索引。 end_index: 样例的结束索引。 Returns: 一个字符串,包含n个样例的提示,格式为text_prompt[i]+'\\n'+triple_prompt[i]。 如果n大于可用样例数量,则返回所有可用样例。 """ text_len = len(text) label_len = len(label) end_index = min(end_index, text_len, label_len) if start_index >= end_index: return "起始索引大于或等于结束索引,无法生成样例。" available_examples = end_index - start_index n = min(n, available_examples) prompt = "" # 随机选择n个不同的索引 random_indices = random.sample(range(start_index, end_index), n) for i in random_indices: text_prompt = text.iloc[i] triple_prompt = label.iloc[i] prompt += text_prompt + '\n' + str(triple_prompt) + '\n' return prompt def generate_prompt_with_best_matches(text_series, label_series, query_text, n=3, start_index=500, end_index=1000): """ 基于相似度匹配生成渐进式提示 参数: text: 文本数据Series label: 三元组标签数据Series query_text: 需要生成提示的查询文本 n: 最大样例数量 start_index: 候选样例起始索引 end_index: 候选样例结束索引 返回: 渐进式提示字符串 (1个样例、2个样例、3个样例...) """ query_text_path = './data/text_retrieval_results.json' with open(query_text_path, 'r', encoding='utf-8') as f: query_text_data = json.load(f) # 根据输入的query_text,在query_text_data中找到对应的query_text索引的matched_texts,基于输入n选择提取matched_texts中context的个数 for item in query_text_data: if item['query_text'] == query_text: matched_texts = item['matched_texts'] break matched_texts = matched_texts[:n] # 将matched_texts中的context提取出来 context_list = [text['context'] for text in matched_texts] # 基于context_list中文本匹配输入text和label作为提示 prompt = "" # 遍历每个上下文 for context in context_list: # 遍历整个文本序列 for idx, text_item in text_series.items(): if context in text_item: # 获取对应的标签 position = text_series.index.get_loc(idx) label = label_series.iloc[position] # print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@") # print(text_item) # print(label) prompt += f"{text_item}\n{label}\n\n" break return prompt.strip() if __name__ == "__main__": # 读取训练数据 train_data = pd.read_json('./data/train_triples.json') text = train_data['text'] label = train_data['triple_list'] query_text_path = './data/text_retrieval_results.json' with open(query_text_path, 'r', encoding='utf-8') as f: query_text_data = json.load(f) query_text = query_text_data[3]['query_text'] print(query_text) print("--------------------------------") prompt = generate_prompt_with_best_matches(text, label, query_text, n=3, start_index=500, end_index=1000) print(prompt)