| | import random
|
| | import json
|
| | from difflib import SequenceMatcher
|
| | import pandas as pd
|
| |
|
def generate_prompt_with_examples(text, label, n, start_index=500, end_index=1000):
    """Build a few-shot prompt from randomly sampled (text, triples) examples.

    Examples are drawn without replacement from the positional window
    ``[start_index, end_index)`` of ``text`` and ``label``.

    Args:
        text: Series of text descriptions (accessed positionally via ``iloc``).
        label: Series of triple lists, positionally aligned with ``text``.
        n: Number of examples to sample. Values larger than the window size
            are capped to the window size; negative values are treated as 0.
        start_index: First candidate position (inclusive). Negative values
            are clamped to 0 so they cannot wrap around to the end of the
            Series.
        end_index: Last candidate position (exclusive). Clamped to the length
            of the shorter of the two Series.

    Returns:
        A string in which every sampled example contributes
        ``"<text>\\n<triples>\\n"``. When the window is empty, an explanatory
        message string is returned instead of a prompt (kept for backward
        compatibility with existing callers).
    """
    # Negative positions would be interpreted by ``iloc`` as offsets from the
    # end, silently sampling rows outside the intended window.
    start_index = max(start_index, 0)
    end_index = min(end_index, len(text), len(label))

    if start_index >= end_index:
        return "起始索引大于或等于结束索引,无法生成样例。"

    # ``random.sample`` rejects negative sizes and sizes above the population.
    n = max(0, min(n, end_index - start_index))
    sampled_positions = random.sample(range(start_index, end_index), n)

    # f-string formatting stringifies both parts, so non-string text values
    # (e.g. NaN) no longer break the concatenation.
    return "".join(
        f"{text.iloc[pos]}\n{label.iloc[pos]}\n" for pos in sampled_positions
    )
|
| |
|
def generate_prompt_with_best_matches(text_series, label_series, query_text, n=3,
                                      start_index=500, end_index=1000,
                                      retrieval_path='./data/text_retrieval_results.json'):
    """Build a few-shot prompt from precomputed retrieval matches for a query.

    The retrieval file is a JSON list of records, each carrying a
    ``'query_text'`` and a ``'matched_texts'`` list whose entries have a
    ``'context'`` string. The record whose ``'query_text'`` equals
    ``query_text`` is selected and its first ``n`` matches are used
    (presumably the file stores matches best-first; confirm against the
    script that produces it). For each context, the first row of
    ``text_series`` containing that context as a substring is emitted
    together with the label at the same position.

    Args:
        text_series: Series of candidate texts.
        label_series: Series of triple lists, positionally aligned with
            ``text_series``.
        query_text: Query whose retrieval matches should be used.
        n: Maximum number of matches to use. Negative values are treated
            as 0.
        start_index: Accepted for backward compatibility; currently unused
            (every row of ``text_series`` is searched).
        end_index: Accepted for backward compatibility; currently unused.
        retrieval_path: Path to the JSON retrieval results file.

    Returns:
        The selected examples formatted as ``"<text>\\n<triples>"`` blocks
        separated by blank lines, with surrounding whitespace stripped.
        Contexts that are not found in ``text_series`` are skipped, so the
        result may hold fewer than ``n`` examples (or be empty).

    Raises:
        ValueError: If the retrieval file has no record for ``query_text``.
    """
    with open(retrieval_path, 'r', encoding='utf-8') as f:
        retrieval_results = json.load(f)

    matched_texts = next(
        (item['matched_texts'] for item in retrieval_results
         if item['query_text'] == query_text),
        None,
    )
    if matched_texts is None:
        # Previously this path failed with an opaque UnboundLocalError.
        raise ValueError(
            f"No retrieval results found for query text {query_text!r} "
            f"in {retrieval_path!r}"
        )

    # A negative n would otherwise slice from the end of the match list.
    contexts = [match['context'] for match in matched_texts[:max(n, 0)]]

    examples = []
    for context in contexts:
        # Pair rows positionally; this avoids index.get_loc, which returns a
        # slice/mask (not an int) when the Series index has duplicates.
        for text_item, triples in zip(text_series, label_series):
            # Skip non-string rows (e.g. NaN) instead of raising TypeError.
            if isinstance(text_item, str) and context in text_item:
                examples.append(f"{text_item}\n{triples}")
                break

    return "\n\n".join(examples).strip()
|
| |
|
if __name__ == "__main__":
    # Manual demo: load the training triples, take the query text of the
    # fourth retrieval record, and print the prompt built from its matches.
    train_data = pd.read_json('./data/train_triples.json')
    text, label = train_data['text'], train_data['triple_list']

    with open('./data/text_retrieval_results.json', 'r', encoding='utf-8') as results_file:
        query_text = json.load(results_file)[3]['query_text']

    print(query_text)
    print("--------------------------------")
    print(
        generate_prompt_with_best_matches(
            text, label, query_text, n=3, start_index=500, end_index=1000
        )
    )
|
| |
|
| |
|