File size: 8,010 Bytes
217acfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import uuid

# Data type yielded by a Writer; it is also the "key point" message displayed by the front end.
class KeyPointMsg(dict):
    """Dict-based "key point" message.

    A message is built in exactly one of two forms:

    * title form  -- both ``title`` and ``subtitle`` given, no ``prompt_name``;
    * prompt form -- only ``prompt_name`` given.

    The instance is a plain ``dict`` with the keys ``id`` (a random UUID4
    string generated at construction), ``title``, ``subtitle``,
    ``prompt_name`` and ``finished`` (initially ``False``).

    Raises:
        ValueError: if the arguments match neither of the two forms.
    """

    # Longest prompt name shown in full by the ``prompt_name`` property.
    _PROMPT_NAME_MAX_LEN = 10

    def __init__(self, title='', subtitle='', prompt_name=''):
        super().__init__()

        prompt_form = bool(prompt_name) and not title and not subtitle
        title_form = bool(title) and bool(subtitle) and not prompt_name
        if not (prompt_form or title_form):
            raise ValueError('Either title and subtitle or prompt_name must be provided')

        self.update({
            'id': str(uuid.uuid4()),
            'title': title,
            'subtitle': subtitle,
            'prompt_name': prompt_name,
            'finished': False
        })

    def set_finished(self):
        """Mark the message as finished and return ``self`` for chaining.

        Asserts that the flag has not been set before.
        """
        assert not self['finished'], 'finished flag is already set'
        self['finished'] = True
        return self

    def is_finished(self):
        """Return the current value of the ``finished`` flag."""
        return self['finished']

    def is_prompt(self):
        """Return True when the message carries a non-empty prompt name."""
        # Test the stored value directly; the display form is not needed here.
        return bool(self['prompt_name'])

    def is_title(self):
        """Return True when the message carries a non-empty title."""
        return bool(self.title)

    @property
    def id(self):
        """The UUID string stored under the ``id`` key."""
        return self['id']

    @property
    def title(self):
        """The value stored under the ``title`` key."""
        return self['title']

    @property
    def subtitle(self):
        """The value stored under the ``subtitle`` key."""
        return self['subtitle']

    @property
    def prompt_name(self):
        """Display form of the prompt name.

        Names longer than ``_PROMPT_NAME_MAX_LEN`` characters are cut to that
        length and suffixed with ``'...'``; shorter names, including names of
        exactly that length, are returned unchanged.  The untruncated value
        remains available as ``self['prompt_name']``.
        """
        prompt_name = self['prompt_name']
        limit = self._PROMPT_NAME_MAX_LEN
        # Strictly greater: a name that already fits must not get an ellipsis.
        if len(prompt_name) > limit:
            return prompt_name[:limit] + '...'
        return prompt_name


import re
from difflib import Differ

# 后续考虑采用现成的库实现,目前逻辑过于繁琐,而且太慢了
def detect_max_edit_span(a, b):
    """Measure how much of ``a``/``b`` is unchanged before and after the edits.

    The two sequences are compared with ``difflib.Differ`` and its output is
    scanned once.

    Returns:
        A tuple ``(prefix, -suffix)`` where ``prefix`` is the number of
        unchanged diff entries before the first changed entry and ``suffix``
        is the number of unchanged entries after the last changed entry.
        The second value is negated (presumably so it can serve as a negative
        slice bound -- verify against callers); it is ``0`` when no unchanged
        entries trail the last change or when there is no change at all.
    """
    prefix_len = 0
    suffix_len = 0
    seen_change = False

    for entry in Differ().compare(a, b):
        if not entry.startswith(' '):
            # Any non-"unchanged" entry restarts the trailing count.
            seen_change = True
            suffix_len = 0
        elif seen_change:
            suffix_len += 1
        else:
            prefix_len += 1

    return prefix_len, -suffix_len

def split_text_by_separators(text, separators, keep_separators=True):
    """Split ``text`` into paragraphs at runs of the given separators.

    A run of one or more consecutive separators (in any mix) counts as a
    single boundary.  A paragraph is emitted for every piece of content that
    is not blank after ``strip()``; blank pieces (and, when kept, the
    separators around them) are carried over and prepended to the next
    emitted paragraph.  A blank remainder at the very end of the text is
    dropped.

    Args:
        text: The text to split.
        separators: Iterable of separator strings, matched literally.
        keep_separators: When True (default), each separator run stays
            attached to the end of the content that precedes it; when False,
            separators are removed from the result.

    Returns:
        A list of paragraph strings.  If ``separators`` is empty, the whole
        text is returned as a single paragraph (or an empty list if the text
        is blank).
    """
    if not separators:
        # An empty alternation would build an invalid pattern; there is
        # simply nothing to split on.
        return [text] if text.strip() else []

    # Wrap the alternation in a non-capturing group so the ``+`` applies to
    # every separator, not only to the last alternative.  The outer capturing
    # group makes re.split() return the separator runs as well.
    alternation = "|".join(map(re.escape, separators))
    chunks = re.split(f'((?:{alternation})+)', text)

    # re.split() with one capturing group yields content, separator, content,
    # ..., content -- pad the separators so the two slices pair up.
    contents = chunks[::2]
    found_separators = chunks[1::2] + ['']

    paragraphs = []
    pending = []
    for content, separator in zip(contents, found_separators):
        pending.append(content)
        if keep_separators and separator:
            pending.append(separator)

        # Only non-blank content closes a paragraph.
        if content.strip():
            paragraphs.append(''.join(pending))
            pending = []

    return paragraphs

def split_text_into_paragraphs(text, keep_separators=True):
    """Split ``text`` on newline characters via ``split_text_by_separators``.

    ``keep_separators`` is forwarded unchanged.
    """
    newline_only = ['\n']
    return split_text_by_separators(
        text,
        newline_only,
        keep_separators=keep_separators,
    )

def split_text_into_sentences(text, keep_separators=True):
    """Split ``text`` on newlines and sentence-ending punctuation.

    Delegates to ``split_text_by_separators`` with the separators newline,
    ``。``, ``?``, ``!`` and ``;``; ``keep_separators`` is forwarded unchanged.
    """
    sentence_breaks = ['\n', '。', '?', '!', ';']
    return split_text_by_separators(
        text,
        sentence_breaks,
        keep_separators=keep_separators,
    )

def run_and_echo_yield_func(func, *args, **kwargs):
    """Run a generator function, echoing its output incrementally to stdout.

    ``func(*args, **kwargs)`` is iterated; each yielded item is treated as a
    list of messages indexed by ``'role'`` and ``'content'``.  The messages
    are rendered to text, and only the part not yet printed is written.  If
    the new rendering does not extend what was printed so far, a divider
    line is printed and the rendering is printed in full.

    Returns:
        A list with every item yielded by the generator, in order.
    """
    shown = ""
    history = []

    for messages in func(*args, **kwargs):
        history.append(messages)
        rendered = "\n".join(f"{msg['role']}:\n{msg['content']}" for msg in messages)

        if rendered.startswith(shown):
            # Same conversation, just longer: print only the new tail.
            pending = rendered[len(shown):]
        else:
            # Output diverged from what is on screen: start a new section.
            print('\n--------------------------------')
            pending = rendered

        print(pending, end="")
        # Either way, everything in ``rendered`` has now been printed.
        shown = rendered

    return history

def run_yield_func(func, *args, **kwargs):
    """Exhaust the generator ``func(*args, **kwargs)`` and return its result.

    Every yielded value is discarded; the value carried by the final
    ``StopIteration`` (the generator's ``return`` value) is returned.
    """
    runner = func(*args, **kwargs)
    while True:
        try:
            next(runner)
        except StopIteration as done:
            return done.value

def split_text_into_chunks(text, max_chunk_size, min_chunk_n, min_chunk_size=1, max_chunk_n=1000):
    """Split ``text`` into chunks that satisfy size and count limits.

    The text is first split into paragraphs with
    ``split_text_into_paragraphs``.  Then, in two phases:

    1. While there are more than ``max_chunk_n`` chunks, or some chunk is
       shorter than ``min_chunk_size``, the adjacent pair with the smallest
       combined length is merged.
    2. While there are fewer than ``min_chunk_n`` chunks, or some chunk is
       longer than ``max_chunk_size``, the longest chunk is cut in two at the
       punctuation mark closest to its middle.

    Returns:
        The resulting list of chunk strings.

    Raises:
        AssertionError: if ``max_chunk_n`` < 1, if the text is shorter than
            ``min_chunk_size``, or if the two loops together run 1000 times.
        Exception: if a chunk that must be cut has no usable cut point, or if
            cutting would violate ``min_chunk_size`` or ``max_chunk_n``.
    """
    def cut_near_middle(segment):
        # Candidate cut positions sit right after each punctuation mark.
        candidates = [hit.end() for hit in re.finditer(r'[。?;]', segment)]
        if not candidates:
            raise Exception("没有找到分割点!")

        centre = len(segment) // 2
        cut = min(candidates, key=lambda pos: abs(pos - centre))
        head, tail = segment[:cut], segment[cut:]
        # Reject a cut that leaves one side blank.
        if not (head.strip() and tail.strip()):
            raise Exception("没有找到分割点!")

        return head, tail

    chunks = split_text_into_paragraphs(text)

    assert max_chunk_n >= 1, "max_chunk_n必须大于等于1"
    total_length = sum(len(chunk) for chunk in chunks)
    assert total_length >= min_chunk_size, f"分割时,输入的文本长度小于要求的min_chunk_size:{min_chunk_size}"

    iterations = 0  # shared by both loops; guards against looping forever

    # Phase 1: merge until neither the count limit nor the minimum size is violated.
    while len(chunks) > max_chunk_n or min(len(chunk) for chunk in chunks) < min_chunk_size:
        iterations += 1
        assert iterations < 1000, "分割进入死循环!"

        # First adjacent pair with the smallest combined length (index 0 if
        # there is no pair at all).
        merge_at = min(
            range(len(chunks) - 1),
            key=lambda idx: len(chunks[idx]) + len(chunks[idx + 1]),
            default=0,
        )
        merged = ''.join(chunks[merge_at:merge_at + 2])
        chunks[merge_at:merge_at + 2] = [merged]

    # Phase 2: cut until there are enough chunks and none is too long.
    while len(chunks) < min_chunk_n or max(len(chunk) for chunk in chunks) > max_chunk_size:
        iterations += 1
        assert iterations < 1000, "分割进入死循环!"

        lengths = [len(chunk) for chunk in chunks]
        longest_at = lengths.index(max(lengths))  # first chunk of maximal length
        head, tail = cut_near_middle(chunks[longest_at])

        too_small = len(head) < min_chunk_size or len(tail) < min_chunk_size
        too_many = len(chunks) + 1 > max_chunk_n
        if too_small or too_many:
            raise Exception("没有找到合适的分割点!")
        chunks[longest_at:longest_at + 1] = [head, tail]

    return chunks

def test_split_text_into_chunks():
    """Smoke-test ``split_text_into_chunks`` on three sample texts.

    Prints each result and asserts on the number of chunks and, for the
    second and third samples, on the maximum chunk length.
    """
    # Case 1: a single paragraph that must end up as exactly three chunks.
    simple_text = "这是第一段。这是第二段。这是第三段。"
    simple_chunks = split_text_into_chunks(simple_text, max_chunk_size=10, min_chunk_n=3)
    print("Test 1 result:", simple_chunks)
    assert len(simple_chunks) == 3, f"Expected 3 chunks, got {len(simple_chunks)}"

    # Cases 2 and 3 share the same checks: a minimum chunk count and a
    # maximum chunk length.  Tuple layout: (number, text, size limit, minimum count).
    cases = (
        # Case 2: one long paragraph.
        (
            2,
            "这是一个很长的段落,包含了很多句子。它应该被分割成多个小块。这里有一些标点符号,比如句号。还有问号?以及分号;这些都可以用来分割文本。",
            20,
            4,
        ),
        # Case 3: text containing newlines.
        (
            3,
            "第一段。\n\n第二段。\n第三段。\n\n第四段很长,需要被分割。这是第四段的继续。",
            15,
            5,
        ),
    )
    for number, sample, size_limit, least in cases:
        chunks = split_text_into_chunks(sample, max_chunk_size=size_limit, min_chunk_n=least)
        print(f"Test {number} result:", chunks)
        assert len(chunks) >= least, f"Expected at least {least} chunks, got {len(chunks)}"
        assert all(len(chunk) <= size_limit for chunk in chunks), "Some chunks are longer than max_chunk_size"

    print("All tests passed!")

if __name__ == "__main__":
    # Demo: print the edit span between one base string and several variants.
    base_text = "我吃西红柿"
    variants = (
        "我不喜欢吃西红柿",
        "不喜欢吃西红柿",
        "我不喜欢吃",
        "你不喜欢吃西瓜",
    )
    for variant in variants:
        print(detect_max_edit_span(base_text, variant))

    # Then run the chunk-splitting smoke test.
    test_split_text_into_chunks()