File size: 3,648 Bytes
217acfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import re
from prompts.chat_utils import chat, log
from prompts.pf_parse_chat import parse_chat
from prompts.prompt_utils import load_text, match_code_block

def parser(response_msgs):
    content = response_msgs.response
    blocks = match_code_block(content)
    if blocks:
        concat_blocks = "\n".join(blocks)
        if concat_blocks.strip():
            content = concat_blocks
    return content


def clean_txt_content(content):
    """Remove comments and trim empty lines from txt content"""
    lines = []
    for line in content.split('\n'):
        if not line.startswith('//'):
            lines.append(line)
    return '\n'.join(lines).strip()


def load_prompt(dirname, name):
    txt_path = os.path.join(dirname, f"{name}.txt")
    text = load_text(txt_path)

    return text

def parse_prompt(text, **kwargs):
    """

        从text中解析PromptMessages。

        对于传入的key-values, key可以多也可以少。

        少的key和value为空的那轮对话会被删除。

        多的key不会管。

    """
    content = clean_txt_content(text)

    # Find all format keys in content using regex
    format_keys = set(re.findall(r'\{(\w+)\}', content))
    
    formatted_kwargs = {k: kwargs.get(k, '__delete__') or '__delete__' for k in format_keys}
    formatted_kwargs = {k: f"```\n{v.strip()}\n```" for k, v in formatted_kwargs.items()}
    prompt = content.format(**formatted_kwargs) if format_keys else content
    messages = parse_chat(prompt)
    for i in range(len(messages)-2, -1, -1):
        if '__delete__' in messages[i]['content']:
            assert messages[i]['role'] == 'user' and messages[i+1]['role'] == 'assistant', "__delete__ must be in user's message"
            messages.pop(i)
            messages.pop(i)
    
    return messages


def parse_input_keys(text):
    # Use regex to find the input keys line and parse keys
    match = re.search(r'//\s*输入:(.*?)(?:\n|$)', text)
    if not match:
        return []
        
    keys_str = match.group(1).strip()
        
    keys = [k.strip() for k in keys_str.split(',') if k.strip()]
    
    return keys

def main(model, dirname, user_prompt_text, **kwargs):
    # Load system prompt
    system_prompt = parse_prompt(load_prompt(dirname, "system_prompt"), **kwargs)
    
    load_from_file_flag = False
    if os.path.exists(os.path.join(dirname, user_prompt_text)):
        user_prompt_text = load_prompt(dirname, user_prompt_text)
        load_from_file_flag = True
    else:
        if not re.search(r'^user:\n', user_prompt_text, re.MULTILINE):
            user_prompt_text = f"user:\n{user_prompt_text}"
        
    user_prompt = parse_prompt(user_prompt_text, **kwargs)
    
    context_input_keys = parse_input_keys(user_prompt_text)
    if not context_input_keys:
        assert not load_from_file_flag, "从本地文件加载Prompt时,本地文件中注释必须指明输入!"
        context_kwargs = kwargs
    else:
        context_kwargs = {k: kwargs[k] for k in context_input_keys}
        assert all(context_kwargs.values()), "Missing required context keys"
    
    context_prompt = parse_prompt(load_prompt(dirname, "context_prompt"), **context_kwargs)
    
    # Combine all prompts
    final_prompt = system_prompt + context_prompt + user_prompt
    
    # Chat and parse results
    for response_msgs in chat(final_prompt, None, model, parse_chat=False):
        text = parser(response_msgs)
        ret = {'text': text, 'response_msgs': response_msgs}
        yield ret

    return ret