import argparse
import os
import re
import time

import openai
import requests
from nltk.tokenize import sent_tokenize

# Matches the leading markdown heading ("jinhao", 井号) prefix of a line,
# e.g. "##  title" -> "## ". The pattern always matches (possibly empty),
# so .match() never returns None. Compiled once; used by three hot helpers.
_JINHAO_RE = re.compile(r'^( *#*)*')


def replace_jinhao(line, replacement=None):
    """Replace the leading '#' heading prefix of *line* with *replacement*.

    No-op when *replacement* is None or when the line carries no '#' prefix.
    """
    if replacement is not None and _JINHAO_RE.match(line)[0].strip() != '':
        # Replace the whole prefix with *replacement*.
        return _JINHAO_RE.sub(replacement, line, count=1)
    return line


def count_jinhao(line):
    """Return the number of '#' characters in the line's heading prefix."""
    return _JINHAO_RE.match(line)[0].count('#')


def is_english(strs):
    """Heuristic: True when *strs* contains no CJK character (U+4E00..U+9FA5)."""
    return not any('\u4e00' <= _char <= '\u9fa5' for _char in strs)


def sentence_split_en(line):
    """Split English text into sentences via NLTK.

    Sentences shorter than 10 characters are merged into the following
    sentence so that no fragment stands alone.
    """
    res = [s.strip() for s in sent_tokenize(line) if s.strip() != '']
    idx = 0
    while idx < len(res) - 1:
        if len(res[idx]) < 10:
            # Too short to stand alone: merge forward.
            res[idx + 1] = res[idx] + ' ' + res[idx + 1]
            res.pop(idx)
        else:
            idx += 1
    return res


def sentence_split_zh(line):
    """Split Chinese text on the full stop '。'.

    A candidate split point is skipped when the full stop is adjacent to a
    digit (likely an OCR artifact) or when the resulting fragment would be
    5 characters or fewer.
    """
    res = []
    pre_idx = 0
    for i in range(1, len(line) - 1):
        if line[i] != '。':  # split on the Chinese full stop only
            continue
        if line[i - 1] in '0123456789':  # digit before: likely OCR error
            continue
        if line[i + 1] in '0123456789':  # digit after: likely OCR error
            continue
        if len(line[pre_idx: i + 1].strip()) <= 5:
            continue
        res.append(line[pre_idx: i + 1].strip())
        pre_idx = i + 1
    if pre_idx < len(line):
        res.append(line[pre_idx:])
    return res


def sentence_split(line):
    """Dispatch to the English or Chinese sentence splitter."""
    if is_english(line):
        return sentence_split_en(line)
    return sentence_split_zh(line)


def sentence_truncation(line, head_limit=15, tail_limit=15):
    """Keep only the head and tail of an over-long line.

    Limits are scaled by a per-language factor (x10 for English, x1 for
    Chinese). Non-positive total limits disable truncation entirely.
    """
    total_limit = head_limit + tail_limit
    len_factor = 10 if is_english(line) else 1
    if 0 < total_limit * len_factor < len(line):
        _head_limit = head_limit * len_factor
        _tail_limit = len(line) - tail_limit * len_factor
        line = line[:_head_limit] + line[_tail_limit:]
    return line


def text2sentence(lines, replacement=None, head_limit=15, tail_limit=15):
    """Split raw text lines into sentences, then normalize each sentence.

    :param lines: iterable of raw text lines.
    :param replacement: in [None, '# ', '']. Final replacement for the
        heading prefix of each sentence; None means no replacement.
    :param head_limit: head truncation limit (see sentence_truncation).
    :param tail_limit: tail truncation limit.
    :return: list of normalized sentences, each terminated with a newline.
    """
    res = []
    for line in lines:
        res.extend(sentence_split(line))
    for idx, temp in enumerate(res):
        # Collapse the prefix to '# ' so truncation cannot eat heading marks,
        # truncate, then restore the original heading depth.
        _temp = replace_jinhao(temp, '# ')
        _temp = sentence_truncation(_temp, head_limit, tail_limit)
        _temp = replace_jinhao(_temp, f"{'#' * count_jinhao(temp)} ")
        _temp = replace_jinhao(_temp, replacement)
        res[idx] = _temp + '\n'
    return res


PROMPT = ('You are an assistant good at reading and formatting documents, and you are also skilled at distinguishing '
          'the semantic and logical relationships of sentences between document context. The following is a text that '
          'has already been divided into sentences. Each line is formatted as: "{line number} @ {sentence content}". '
          'You need to segment this text based on semantics and format. There are multiple levels of granularity for '
          'segmentation, the higher level number means the finer granularity of the segmentation. Please ensure that '
          'each Level One segment is semantically complete after segmentation. A Level One segment may contain '
          'multiple Level Two segments, and so on. Please incrementally output the starting line numbers of each level '
          'of segments, and determine the level of the segment, as well as whether the content of the sentence at the '
          'starting line number can be used as the title of the segment. Finally, output a list format result, '
          'where each element is in the format of: "{line number}, {segment level}, {be a title?}".'
          '\n\n>>> Input text:\n')


def index_format(idx, line):
    """Format one instruction line as '{line number} @ {sentence content}'."""
    return f'{idx} @ {line}'


def points2clip(points, start_idx, end_idx):
    """Turn split points into consecutive [start, end) clips.

    :param points: [a, b, c, d]
    :param start_idx: x
    :param end_idx: y
    assert: x <= a < b < c < d < y
    return [[x, a], [a, b], [b, c], [c, d], [d, y]]
    """
    clips = []
    pre_p = start_idx
    for p in points:
        # Points at the start or at/past the end contribute no boundary.
        if p == start_idx or p >= end_idx:
            continue
        clips.append([pre_p, p])
        pre_p = p
    clips.append([pre_p, end_idx])
    return clips


# Segment-level names as they appear in the model's answer lines.
level_dict_en = {
    0: 'Level One',
    1: 'Level Two',
    2: 'Level Three',
    3: 'Level Four',
    4: 'Level Five',
    5: 'Level Six',
    6: 'Level Seven',
    7: 'Level Eight',
    8: 'Level Nine',
    9: 'Level Ten',
}


def parse_answer_chunking_point(answer_string, max_level):
    """Parse the model answer into per-level lists of chunking points.

    Each answer line is expected as '{point}, {level name}, {is title}'.
    Lines whose level name is unknown are ignored; malformed lines raise
    ValueError (the caller treats that as a parse error).
    Each level's points are filtered to be strictly increasing.
    """
    local_chunk_points = {level_dict_en[i]: [] for i in range(max_level)}
    for line in answer_string.split('\n'):
        point, level, _ = line.split(', ')
        if level in local_chunk_points:
            local_chunk_points[level].append(int(point))
    res = list(local_chunk_points.values())
    for idx, points in enumerate(res):
        if len(points) == 0:
            continue
        # Keep a point only when it is greater than its original predecessor.
        res[idx] = [points[0]] + [points[i] for i in range(1, len(points)) if points[i] > points[i - 1]]
    return res


def check_answer_point(first_level_points, start_idx, end_idx):
    """Validate that level-one points lie in [start_idx, end_idx] and strictly increase."""
    print('parsed_answer:', first_level_points, start_idx, end_idx)
    if len(first_level_points) > 0 and first_level_points[0] < start_idx:
        return False
    for idx in range(1, len(first_level_points)):
        p = first_level_points[idx]
        if p <= first_level_points[idx - 1] or p > end_idx:
            return False
    return True


def build_residual_lines(lines, global_chunk_points, start_idx, window_size, recurrent_type):
    """Build carry-over context lines for the next inference iteration.

    Only recurrent_type 2 carries residual context; types 0 and 1 carry none.
    """
    if recurrent_type in [0, 1]:
        return []
    assert recurrent_type == 2, f'Not implemented for recurrent_type: {recurrent_type}'
    last_first_point = 0
    if len(global_chunk_points[0]) > 0:
        last_first_point = global_chunk_points[0][-1]
    current_second_points = filter(lambda p: p >= last_first_point, global_chunk_points[1])
    temp_second_clips = points2clip(current_second_points, last_first_point, start_idx)
    # Within one level-one segment keep at most 5 level-two clips
    # (first 2 + last 3), and at most 20 lines per level-two clip.
    pre_seg_num, post_seg_num, line_num = 2, 3, 20
    while True:
        residual_second_clips = temp_second_clips
        if len(temp_second_clips) > (pre_seg_num + post_seg_num):
            residual_second_clips = (
                temp_second_clips[:pre_seg_num]
                + temp_second_clips[len(temp_second_clips) - post_seg_num:]
            )
        residual_lines = []
        for rsc in residual_second_clips:
            # Keep at most `line_num` lines per level-two clip.
            pre_sent_idx, post_sent_idx = rsc[0], min(rsc[1], rsc[0] + line_num)
            residual_lines.extend(lines[pre_sent_idx: post_sent_idx])
        if len('\n'.join(residual_lines)) < window_size / 2:
            print(residual_lines)
            return residual_lines
        # Residual exceeds half the inference window: shrink the budget
        # (one fewer head clip, one fewer tail clip, 5 fewer lines per clip).
        pre_seg_num, post_seg_num, line_num = pre_seg_num - 1, post_seg_num - 1, line_num - 5
        # Still too large at the minimum setting: carry no residual input.
        if pre_seg_num * post_seg_num * line_num <= 0:
            return []


def union_chunk_points(local_chunk_points, global_chunk_points, max_idx):
    """Merge local points (already in global coordinates) below *max_idx* into the global lists."""
    for idx, _ in enumerate(global_chunk_points):
        global_chunk_points[idx].extend(filter(lambda p: p < max_idx, local_chunk_points[idx]))
    return global_chunk_points


class HiChunkInferenceEngine:
    """Hierarchically chunks a document by iteratively prompting an LLM.

    The engine slides a token-budgeted window over the document sentences,
    asks the model for hierarchical segment boundaries, and stitches the
    per-window answers into global chunk points.
    """

    def __init__(self, window_size, line_max_len, max_level, prompt):
        # Token budget for one inference call (prompt + sentences).
        self.window_size = window_size
        # Per-line truncation limit applied to model-input sentences.
        self.line_max_len = line_max_len
        # Number of segmentation levels tracked.
        self.max_level = max_level
        self.prompt = prompt
        self.base_url = os.environ.get("OPENAI_BASE_URL", "http://localhost:8000")
        self.llm = openai.Client(base_url=f"{self.base_url}/v1", api_key="[empty]")

    def init_chunk_points(self):
        """Return a fresh empty per-level list of chunk points."""
        return [[] for _ in range(self.max_level)]

    def build_input_instruction(self, prompt, global_start_idx, sentences, window_size, residual_lines=None):
        """Build the input instruction for one inference call.

        :param prompt: instruction prefix.
        :param global_start_idx: global start index into *sentences*.
        :param sentences: all (model-input) document sentences.
        :param window_size: token budget for the whole instruction.
        :param residual_lines: optional carry-over lines, indexed locally
            from 0 before the new sentences.
        :return: (instruction text, reached-end flag, number of new sentences consumed)
        """
        q = prompt
        # Prepend residual (carry-over) lines, if any.
        residual_index = 0
        while residual_lines is not None and residual_index < len(residual_lines):
            q = q + index_format(residual_index, residual_lines[residual_index])
            residual_index += 1
        assert self.count_length(q) <= window_size, 'residual lines exceeds window size'
        local_start_idx = 0
        cur_token_num = self.count_length(q)
        end = False
        # Append sentences until the token budget would be exceeded.
        while global_start_idx < len(sentences):
            line_text = index_format(local_start_idx + residual_index, sentences[global_start_idx])
            line_token_num = self.count_length(line_text)
            if cur_token_num + line_token_num > window_size:
                break
            cur_token_num += line_token_num
            q = q + line_text
            local_start_idx += 1
            global_start_idx += 1
        if global_start_idx == len(sentences):
            end = True
        return q, end, local_start_idx

    def call_llm(self, input_text):
        """Send one chat completion request and return the answer text."""
        response = self.llm.chat.completions.create(
            model='HiChunk',
            messages=[{'role': 'user', 'content': input_text}],
            temperature=0.0,
            max_tokens=4096,
            extra_body={
                "chat_template_kwargs": {"add_generation_prompt": True, "enable_thinking": False}
            }
        )
        return response.choices[0].message.content

    def count_length(self, text):
        """Return the token count of *text* via the server's /tokenize endpoint."""
        response = requests.post(
            url=f'{self.base_url}/tokenize',
            json={'model': 'HiChunk', 'prompt': text}
        ).json()
        return response['count']

    def pre_process(self, document):
        """Split *document* into sentence lists for inference and reassembly.

        :return: (input_lines truncated to self.line_max_len for the model,
                  origin_lines untruncated for final reassembly)
        """
        lines = map(lambda l: l.strip(), document.split('\n'))
        lines = list(filter(lambda l: len(l) != 0, lines))
        origin_lines = text2sentence(lines, None, -1, 0)  # original lines, no truncation
        input_lines = text2sentence(lines, '', self.line_max_len, 0)  # truncated to self.line_max_len
        return input_lines, origin_lines

    @staticmethod
    def post_process(origin_lines, global_chunk_points):
        """Assemble [chunk text, level] pairs from the global chunk points."""
        origin_lines_remove_jinhao = [replace_jinhao(l, '') for l in origin_lines]
        # Flatten to [point, 1-based level], sorted by point.
        total_points = sorted(
            [[p, lvl + 1] for lvl, pts in enumerate(global_chunk_points) for p in pts],
            key=lambda item: item[0]
        )
        splits = []
        pre_level, pre_point = 1, 0
        for p, level in total_points:
            if p == 0:
                continue
            splits.append([''.join(origin_lines_remove_jinhao[pre_point: p]), pre_level])
            pre_level = level
            pre_point = p
        splits.append([''.join(origin_lines_remove_jinhao[pre_point:]), pre_level])
        return splits

    def iterative_inf(self, lines, recurrent_type=1):
        """Iteratively run the model over *lines* and collect global chunk points.

        :param lines: model-input sentences (possibly truncated).
        :param recurrent_type: 0/1 = no residual context between iterations,
            2 = carry residual lines (see build_residual_lines).
        :return: dict with 'global_chunk_points', 'raw_qa', 'error_count'.
        """
        error_count, start_idx = 0, 0
        raw_qa, residual_lines = [], []
        global_chunk_points = self.init_chunk_points()
        while start_idx < len(lines):
            residual_sent_num = len(residual_lines)
            question, is_end, question_sent_num = self.build_input_instruction(
                self.prompt, start_idx, lines, self.window_size, residual_lines
            )
            question_token_num = self.count_length(question)
            print('question len:', len(question), question_token_num)
            start_time = time.time()
            answer = self.call_llm(question)
            inf_time = time.time() - start_time
            answer_token_num = self.count_length(answer)
            print('answer:', answer)
            print('answer len:', answer_token_num)
            tmp = {
                'question': question,
                'answer': answer,
                'start_idx': start_idx,
                'end_idx': start_idx + question_sent_num,
                'residual_sent_num': residual_sent_num,
                'time': inf_time,
                'question_token_num': question_token_num,
                'answer_token_num': answer_token_num,
            }
            # Parse the answer; convert local sentence indices to global ones.
            try:
                local_chunk_points = parse_answer_chunking_point(answer, self.max_level)
                if not check_answer_point(local_chunk_points[0], 0, question_sent_num + residual_sent_num - 1):
                    print('###########check error##############')
                    tmp['status'] = 'check error'
                    local_chunk_points = self.init_chunk_points()
                    local_chunk_points[0].append(start_idx)
                    error_count += 1
                else:
                    tmp['status'] = 'check ok'
                    print('#############check ok################')
                    for idx, points in enumerate(local_chunk_points):
                        filter_points = filter(lambda p: p >= residual_sent_num, points)
                        # p - residual_sent_num + start_idx maps a local
                        # sentence index back to the global document index.
                        local_chunk_points[idx] = [p - residual_sent_num + start_idx for p in filter_points]
            except Exception:  # was a bare except: don't swallow SystemExit/KeyboardInterrupt
                print('##########parsed error################')
                tmp['status'] = 'parse error'
                local_chunk_points = self.init_chunk_points()
                local_chunk_points[0].append(start_idx)
                error_count += 1
            raw_qa.append(tmp)
            print('local_chunk_points:', local_chunk_points)
            if is_end:
                # Whole document consumed: merge the last result and stop.
                start_idx += question_sent_num
                global_chunk_points = union_chunk_points(local_chunk_points, global_chunk_points, start_idx)
                break
            if len(local_chunk_points[0]) > 1 and recurrent_type in [1, 2]:
                # Several level-one segments: drop this round's last level-one
                # segment and restart the next iteration from its beginning.
                start_idx = local_chunk_points[0][-1]
                global_chunk_points = union_chunk_points(local_chunk_points, global_chunk_points, start_idx)
                residual_lines = []
            else:
                # Only one level-one segment: continue from the last consumed
                # line, carrying residual context into the next iteration.
                start_idx += question_sent_num
                global_chunk_points = union_chunk_points(local_chunk_points, global_chunk_points, start_idx)
                residual_lines = build_residual_lines(
                    lines, global_chunk_points, start_idx, self.window_size, recurrent_type
                )
        print('global_chunk_points:', global_chunk_points)
        result = {
            'global_chunk_points': global_chunk_points,
            'raw_qa': raw_qa,
            'error_count': error_count,
        }
        return result

    def inference(self, document, recurrent_type=1):
        """Chunk *document*; return (markdown-style chunked text, chunk list)."""
        input_lines, origin_lines = self.pre_process(document)
        chunked_result = self.iterative_inf(input_lines, recurrent_type=recurrent_type)
        chunks = self.post_process(origin_lines, chunked_result['global_chunk_points'])
        chunked_document = '\n'.join(['#' * c[1] + ' ' + c[0] for c in chunks])
        return chunked_document, chunks