Pengfa Li committed on
Update prompt_generate.py
Browse files- prompt_generate.py +30 -100
prompt_generate.py
CHANGED
|
@@ -1,100 +1,30 @@
|
|
| 1 |
-
import random
|
| 2 |
-
import json
|
| 3 |
-
from difflib import SequenceMatcher
|
| 4 |
-
import pandas as pd
|
| 5 |
-
|
| 6 |
-
def generate_prompt_with_examples(text, label, n, start_index=500, end_index=1000):
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
prompt = ""
|
| 33 |
-
# 随机选择n个不同的索引
|
| 34 |
-
random_indices = random.sample(range(start_index, end_index), n)
|
| 35 |
-
|
| 36 |
-
for i in random_indices:
|
| 37 |
-
text_prompt = text.iloc[i]
|
| 38 |
-
triple_prompt = label.iloc[i]
|
| 39 |
-
prompt += text_prompt + '\n' + str(triple_prompt) + '\n'
|
| 40 |
-
|
| 41 |
-
return prompt
|
| 42 |
-
|
| 43 |
-
def generate_prompt_with_best_matches(text_series, label_series, query_text, n=3, start_index=500, end_index=1000):
    """Build a few-shot prompt from the retrieved best matches of a query.

    Retrieval results are read from ``./data/text_retrieval_results.json``,
    a JSON list of records accessed through the keys ``query_text`` and
    ``matched_texts``; each matched text is accessed through its ``context``
    key. For each of the first ``n`` contexts, the first entry of
    ``text_series`` that contains the context is emitted together with the
    label stored at the same position in ``label_series``.

    Args:
        text_series: Series of candidate example texts.
        label_series: Series of triple labels, positionally aligned with
            ``text_series``.
        query_text: Query whose retrieval record should be used.
        n: Maximum number of matched contexts to turn into examples.
        start_index: Accepted for interface compatibility; currently unused.
        end_index: Accepted for interface compatibility; currently unused.

    Returns:
        The examples formatted as ``text\\nlabel`` and separated by a blank
        line, with surrounding whitespace stripped. An empty string is
        returned when the query has no retrieval record or when none of its
        contexts occurs in ``text_series``.
    """
    query_text_path = './data/text_retrieval_results.json'
    with open(query_text_path, 'r', encoding='utf-8') as f:
        query_text_data = json.load(f)

    # Look up the record of this query; a missing record used to leave
    # ``matched_texts`` unbound and crash with UnboundLocalError.
    matched_texts = next(
        (item['matched_texts'] for item in query_text_data
         if item['query_text'] == query_text),
        None,
    )
    if not matched_texts:
        return ""

    context_list = [match['context'] for match in matched_texts[:n]]

    examples = []
    for context in context_list:
        # Enumerate positionally so the label lookup does not depend on the
        # Series index being unique.
        for position, text_item in enumerate(text_series):
            if context in text_item:
                examples.append(f"{text_item}\n{label_series.iloc[position]}")
                break  # only the first text containing the context is used

    return "\n\n".join(examples).strip()
|
| 86 |
-
|
| 87 |
-
if __name__ == "__main__":
    # Demo run: load the training set, take one stored retrieval query and
    # print the few-shot prompt assembled from its best matches.
    train_data = pd.read_json('./data/train_triples.json')
    text, label = train_data['text'], train_data['triple_list']

    query_text_path = './data/text_retrieval_results.json'
    with open(query_text_path, mode='r', encoding='utf-8') as f:
        query_text_data = json.load(f)

    # The record at position 3 serves as the sample query.
    query_text = query_text_data[3]['query_text']
    print(query_text)
    print("--------------------------------")

    prompt = generate_prompt_with_best_matches(
        text,
        label,
        query_text,
        n=3,
        start_index=500,
        end_index=1000,
    )
    print(prompt)
|
| 100 |
-
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import json
|
| 3 |
+
from difflib import SequenceMatcher
|
| 4 |
+
import pandas as pd
|
| 5 |
+
|
| 6 |
+
def generate_prompt_with_examples(text, label, n: int, start_index: int = 500, end_index: int = 1000) -> str:
    """Build a few-shot prompt from randomly sampled text/label pairs.

    Up to ``n`` distinct positions are drawn at random from
    ``[start_index, end_index)`` and each selected pair is rendered as
    ``text\\nlabel\\n``.

    Args:
        text: Series of example texts, accessed positionally via ``.iloc``.
        label: Series of labels aligned positionally with ``text``.
        n: Requested number of examples; clamped to the number of available
            positions, and to zero when negative.
        start_index: First candidate position (inclusive); negative values
            are treated as 0.
        end_index: Last candidate position (exclusive); clamped to the
            length of the shorter series.

    Returns:
        The concatenated examples, or the literal string ``"none"`` when the
        candidate range is empty.
    """
    # Never sample positions beyond the shorter of the two series.
    end_index = min(end_index, len(text), len(label))
    # A negative start would silently wrap around through ``iloc``.
    start_index = max(start_index, 0)

    if start_index >= end_index:
        return "none"

    # ``random.sample`` rejects negative or oversized sample counts.
    n = max(0, min(n, end_index - start_index))
    random_indices = random.sample(range(start_index, end_index), n)

    # f-string formatting tolerates non-string entries in either series.
    return "".join(
        f"{text.iloc[i]}\n{label.iloc[i]}\n" for i in random_indices
    )
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|