Pengfa Li committed on
Commit
674554d
·
verified ·
1 Parent(s): 991723d

Update prompt_generate.py

Browse files
Files changed (1) hide show
  1. prompt_generate.py +30 -100
prompt_generate.py CHANGED
@@ -1,100 +1,30 @@
1
- import random
2
- import json
3
- from difflib import SequenceMatcher
4
- import pandas as pd
5
-
6
- def generate_prompt_with_examples(text, label, n, start_index=500, end_index=1000):
7
- """
8
- 生成带有n个样例的提示。
9
-
10
- Args:
11
- text: 包含文本描述的 Series。
12
- label: 包含三元组列表的 Series。
13
- n: 要提取的样例数量。
14
- start_index: 样例的起始索引。
15
- end_index: 样例的结束索引。
16
-
17
- Returns:
18
- 一个字符串,包含n个样例的提示,格式为text_prompt[i]+'\\n'+triple_prompt[i]。
19
- 如果n大于可用样例数量,则返回所有可用样例。
20
- """
21
-
22
- text_len = len(text)
23
- label_len = len(label)
24
- end_index = min(end_index, text_len, label_len)
25
-
26
- if start_index >= end_index:
27
- return "起始索引大于或等于结束索引,无法生成样例。"
28
-
29
- available_examples = end_index - start_index
30
- n = min(n, available_examples)
31
-
32
- prompt = ""
33
- # 随机选择n个不同的索引
34
- random_indices = random.sample(range(start_index, end_index), n)
35
-
36
- for i in random_indices:
37
- text_prompt = text.iloc[i]
38
- triple_prompt = label.iloc[i]
39
- prompt += text_prompt + '\n' + str(triple_prompt) + '\n'
40
-
41
- return prompt
42
-
43
- def generate_prompt_with_best_matches(text_series, label_series, query_text, n=3, start_index=500, end_index=1000):
44
- """
45
- 基于相似度匹配生成渐进式提示
46
-
47
- 参数:
48
- text: 文本数据Series
49
- label: 三元组标签数据Series
50
- query_text: 需要生成提示的查询文本
51
- n: 最大样例数量
52
- start_index: 候选样例起始索引
53
- end_index: 候选样例结束索引
54
-
55
- 返回:
56
- 渐进式提示字符串 (1个样例、2个样例、3个样例...)
57
- """
58
- query_text_path = './data/text_retrieval_results.json'
59
- with open(query_text_path, 'r', encoding='utf-8') as f:
60
- query_text_data = json.load(f)
61
- # 根据输入的query_text,在query_text_data中找到对应的query_text索引的matched_texts,基于输入n选择提取matched_texts中context的个数
62
- for item in query_text_data:
63
- if item['query_text'] == query_text:
64
- matched_texts = item['matched_texts']
65
- break
66
- matched_texts = matched_texts[:n]
67
- # 将matched_texts中的context提取出来
68
- context_list = [text['context'] for text in matched_texts]
69
- # 基于context_list中文本匹配输入text和label作为提示
70
- prompt = ""
71
- # 遍历每个上下文
72
- for context in context_list:
73
- # 遍历整个文本序列
74
- for idx, text_item in text_series.items():
75
- if context in text_item:
76
- # 获取对应的标签
77
- position = text_series.index.get_loc(idx)
78
- label = label_series.iloc[position]
79
- # print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
80
- # print(text_item)
81
- # print(label)
82
- prompt += f"{text_item}\n{label}\n\n"
83
- break
84
-
85
- return prompt.strip()
86
-
87
- if __name__ == "__main__":
88
- # 读取训练数据
89
- train_data = pd.read_json('./data/train_triples.json')
90
- text = train_data['text']
91
- label = train_data['triple_list']
92
- query_text_path = './data/text_retrieval_results.json'
93
- with open(query_text_path, 'r', encoding='utf-8') as f:
94
- query_text_data = json.load(f)
95
- query_text = query_text_data[3]['query_text']
96
- print(query_text)
97
- print("--------------------------------")
98
- prompt = generate_prompt_with_best_matches(text, label, query_text, n=3, start_index=500, end_index=1000)
99
- print(prompt)
100
-
 
1
+ import random
2
+ import json
3
+ from difflib import SequenceMatcher
4
+ import pandas as pd
5
+
6
+ def generate_prompt_with_examples(text, label, n, start_index=500, end_index=1000):
7
+
8
+ text_len = len(text)
9
+ label_len = len(label)
10
+ end_index = min(end_index, text_len, label_len)
11
+
12
+ if start_index >= end_index:
13
+ return "none"
14
+
15
+ available_examples = end_index - start_index
16
+ n = min(n, available_examples)
17
+
18
+ prompt = ""
19
+
20
+ random_indices = random.sample(range(start_index, end_index), n)
21
+
22
+ for i in random_indices:
23
+ text_prompt = text.iloc[i]
24
+ triple_prompt = label.iloc[i]
25
+ prompt += text_prompt + '\n' + str(triple_prompt) + '\n'
26
+
27
+ return prompt
28
+
29
+
30
+