File size: 2,333 Bytes
badcf3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Prompt generation utilities
import json

def generate_prompt(type, json_path, example_num=3):
    """Build few-shot question-answering prompts from a JSON file.

    The file is parsed as JSON and iterated. Each item is read through
    ``item['test_query']`` (keys ``text`` and ``question``) and
    ``item['matched_samples']`` (entries with keys ``context``,
    ``question`` and ``answer``). Every prompt consists of a fixed header,
    up to ``example_num`` worked examples, the test query, and a closing
    instruction chosen by ``type``.

    Args:
        type: Question category, either ``'yes_no'`` or ``'factoid'``.
            The name shadows the builtin but is kept so existing keyword
            callers continue to work.
        json_path: Path of the UTF-8 encoded JSON file to read.
        example_num: Upper bound on the number of matched samples that are
            included as examples in each prompt.

    Returns:
        A list with one prompt string per item, in file order.

    Raises:
        ValueError: If ``type`` is not one of the supported categories.
            Previously such a value silently produced prompts that
            contained the examples but never asked the test question.
    """
    # The closing instruction depends only on ``type``, so resolve it once
    # and fail fast instead of emitting incomplete prompts.
    if type == 'yes_no':
        instruction = '请用Yes或No直接回答。'
    elif type == 'factoid':
        instruction = '请直接回答该问题。'
    else:
        raise ValueError(
            f"unsupported question type {type!r}; expected 'yes_no' or 'factoid'"
        )

    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    prompts = []
    for item in data:
        test_query = item['test_query']
        # Cap the number of few-shot examples.
        matched_samples = item['matched_samples'][:example_num]

        # Collect the pieces and join once rather than growing a string.
        parts = ['请根据给定的文本回答问题,\n例子如下:\n']
        parts.extend(
            f'给定文本:"{sample["context"]}"。\n问题: "{sample["question"]}"。\n答案:"{sample["answer"]}"\n\n'
            for sample in matched_samples
        )
        parts.append(
            f'给定文本:"{test_query["text"]}"。\n问题:"{test_query["question"]}"\n{instruction}'
        )
        prompts.append(''.join(parts))

    return prompts

def generate_prompt_cot(type, json_path):
    """Build zero-shot prompts that ask for an answer plus its reasoning.

    The file is parsed as JSON and iterated. Only ``item['test_query']``
    (keys ``text`` and ``question``) is read from each item; no examples
    are included. Every prompt consists of a fixed header, the test query,
    and a closing instruction chosen by ``type`` that asks the model to
    answer first and then give its reasoning.

    Args:
        type: Question category, either ``'yes_no'`` or ``'factoid'``.
            The name shadows the builtin but is kept so existing keyword
            callers continue to work.
        json_path: Path of the UTF-8 encoded JSON file to read.

    Returns:
        A list with one prompt string per item, in file order.

    Raises:
        ValueError: If ``type`` is not one of the supported categories.
            Previously such a value silently produced prompts that held
            only the header line and no question.
    """
    # The closing instruction depends only on ``type``, so resolve it once
    # and fail fast instead of emitting header-only prompts.
    if type == 'yes_no':
        instruction = '请先回答Yes或No,随后给出你的推理依据'
    elif type == 'factoid':
        instruction = '请首先回答该问题,并给出你的推理依据。'
    else:
        raise ValueError(
            f"unsupported question type {type!r}; expected 'yes_no' or 'factoid'"
        )

    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    prompts = []
    for item in data:
        test_query = item['test_query']
        prompts.append(
            '请根据给定的文本回答问题,\n'
            f'给定文本:"{test_query["text"]}"。\n问题:"{test_query["question"]}"\n'
            f'{instruction}'
        )

    return prompts


if __name__ == "__main__":
    # Usage example: answer-then-reasoning prompts for yes/no questions.
    cot_prompts = generate_prompt_cot('yes_no', './data/knn_yes_no.json')
    # Show only the first two prompts.
    for number, text in enumerate(cot_prompts[:2], start=1):
        print(f'Prompt {number}:\n{text}\n')

    # Few-shot prompts for factoid questions, one example each.
    few_shot_prompts = generate_prompt('factoid', './data/knn_factoid.json', example_num=1)
    # Separator line, then again only the first two prompts.
    print('knnnnnnnnnnnnnnnnnnnn')
    for number, text in enumerate(few_shot_prompts[:2], start=1):
        print(f'Prompt {number}:\n{text}\n')