| import random
|
| import json
|
| from difflib import SequenceMatcher
|
| import pandas as pd
|
|
|
| def generate_prompt_with_examples(text, label, n, start_index=500, end_index=1000):
|
| """
|
| 生成带有n个样例的提示。
|
|
|
| Args:
|
| text: 包含文本描述的 Series。
|
| label: 包含三元组列表的 Series。
|
| n: 要提取的样例数量。
|
| start_index: 样例的起始索引。
|
| end_index: 样例的结束索引。
|
|
|
| Returns:
|
| 一个字符串,包含n个样例的提示,格式为text_prompt[i]+'\\n'+triple_prompt[i]。
|
| 如果n大于可用样例数量,则返回所有可用样例。
|
| """
|
|
|
| text_len = len(text)
|
| label_len = len(label)
|
| end_index = min(end_index, text_len, label_len)
|
|
|
| if start_index >= end_index:
|
| return "起始索引大于或等于结束索引,无法生成样例。"
|
|
|
| available_examples = end_index - start_index
|
| n = min(n, available_examples)
|
|
|
| prompt = ""
|
|
|
| random_indices = random.sample(range(start_index, end_index), n)
|
|
|
| for i in random_indices:
|
| text_prompt = text.iloc[i]
|
| triple_prompt = label.iloc[i]
|
| prompt += text_prompt + '\n' + str(triple_prompt) + '\n'
|
|
|
| return prompt
|
|
|
| def generate_prompt_with_best_matches(text_series, label_series, query_text, n=3, start_index=500, end_index=1000):
|
| """
|
| 基于相似度匹配生成渐进式提示
|
|
|
| 参数:
|
| text: 文本数据Series
|
| label: 三元组标签数据Series
|
| query_text: 需要生成提示的查询文本
|
| n: 最大样例数量
|
| start_index: 候选样例起始索引
|
| end_index: 候选样例结束索引
|
|
|
| 返回:
|
| 渐进式提示字符串 (1个样例、2个样例、3个样例...)
|
| """
|
| query_text_path = './data/text_retrieval_results.json'
|
| with open(query_text_path, 'r', encoding='utf-8') as f:
|
| query_text_data = json.load(f)
|
|
|
| for item in query_text_data:
|
| if item['query_text'] == query_text:
|
| matched_texts = item['matched_texts']
|
| break
|
| matched_texts = matched_texts[:n]
|
|
|
| context_list = [text['context'] for text in matched_texts]
|
|
|
| prompt = ""
|
|
|
| for context in context_list:
|
|
|
| for idx, text_item in text_series.items():
|
| if context in text_item:
|
|
|
| position = text_series.index.get_loc(idx)
|
| label = label_series.iloc[position]
|
|
|
|
|
|
|
| prompt += f"{text_item}\n{label}\n\n"
|
| break
|
|
|
| return prompt.strip()
|
|
|
| if __name__ == "__main__":
|
|
|
| train_data = pd.read_json('./data/train_triples.json')
|
| text = train_data['text']
|
| label = train_data['triple_list']
|
| query_text_path = './data/text_retrieval_results.json'
|
| with open(query_text_path, 'r', encoding='utf-8') as f:
|
| query_text_data = json.load(f)
|
| query_text = query_text_data[3]['query_text']
|
| print(query_text)
|
| print("--------------------------------")
|
| prompt = generate_prompt_with_best_matches(text, label, query_text, n=3, start_index=500, end_index=1000)
|
| print(prompt)
|
|
|
|
|