# Youtu-HiChunk / HiChunk.py
# Uploaded by Luuuuk via huggingface_hub (commit 704323c, verified)
import argparse
import os
import re
import time
import openai
import requests
from nltk.tokenize import sent_tokenize
def replace_jinhao(line, replacement=None):
    """Swap the leading space/'#' prefix of *line* for *replacement*.

    The prefix is the run of spaces and '#' characters at the start of the
    line.  The line is returned unchanged when *replacement* is None or when
    the prefix contains no '#' at all (spaces only, or empty).
    """
    if replacement is None:
        return line
    prefix = re.match(r'^( *#*)*', line)[0]
    if not prefix.strip():
        # No '#' in the prefix: nothing to replace.
        return line
    return re.sub(r'^( *#*)*', replacement, line, count=1)
def count_jinhao(line):
    """Return how many '#' characters appear in the leading space/'#' prefix."""
    leading = re.match(r'^( *#*)*', line).group(0)
    return leading.count('#')
def is_english(strs):
    """Return True when *strs* contains no CJK character (U+4E00..U+9FA5)."""
    return all(not ('\u4e00' <= ch <= '\u9fa5') for ch in strs)
def sentence_split_en(line):
    """Tokenize *line* into sentences via NLTK, then merge any sentence
    shorter than 10 characters into the sentence that follows it."""
    sents = [s.strip() for s in sent_tokenize(line) if s.strip() != '']
    pos = 0
    while pos + 1 < len(sents):
        if len(sents[pos]) < 10:
            # Too short to stand alone: glue it onto the next sentence.
            sents[pos + 1] = sents[pos] + ' ' + sents[pos + 1]
            del sents[pos]
        else:
            pos += 1
    return sents
def sentence_split_zh(line):
    """Split Chinese text on the full stop '。'.

    A '。' adjacent to an ASCII digit is ignored (treated as an OCR error),
    and a candidate segment of 5 characters or fewer (after stripping) is
    kept attached to the text that follows it.  Any trailing remainder is
    appended as the final segment, unstripped.
    """
    segments = []
    start = 0
    digits = '0123456789'
    for pos in range(1, len(line) - 1):
        if line[pos] != '。' or line[pos - 1] in digits or line[pos + 1] in digits:
            continue
        candidate = line[start: pos + 1].strip()
        if len(candidate) > 5:
            segments.append(candidate)
            start = pos + 1
    if start < len(line):
        segments.append(line[start:])
    return segments
def sentence_split(line):
    """Dispatch *line* to the English or Chinese sentence splitter."""
    splitter = sentence_split_en if is_english(line) else sentence_split_zh
    return splitter(line)
def sentence_truncation(line, head_limit=15, tail_limit=15):
    """Keep only the head and tail of an over-long sentence.

    Limits are measured in units of 10 characters for English text and 1
    character for Chinese.  A line already within the combined budget, or a
    non-positive budget (e.g. head_limit=-1), passes through untouched.
    """
    unit = 10 if is_english(line) else 1
    budget = (head_limit + tail_limit) * unit
    if not (0 < budget < len(line)):
        return line
    keep_head = line[:head_limit * unit]
    keep_tail = line[len(line) - tail_limit * unit:]
    return keep_head + keep_tail
def text2sentence(lines, replacement=None, head_limit=15, tail_limit=15):
    """Split each input line into sentences, truncate every sentence, and
    normalize its leading '#' prefix.

    :param lines: iterable of raw text lines.
    :param replacement: in [None, '# ', ''].  Final replacement applied to the
        leading '#' prefix of each sentence; None means no replacement.
    :param head_limit: head budget forwarded to sentence_truncation.
    :param tail_limit: tail budget forwarded to sentence_truncation.
    :return: list of sentences, each terminated with a newline.
    """
    res = []
    # Fix: the original iterated with enumerate() but never used the index.
    for line in lines:
        res.extend(sentence_split(line))
    for idx, temp in enumerate(res):
        # Normalize the prefix to a single '# ' so that a long run of hashes
        # does not eat into the truncation budget ...
        _temp = replace_jinhao(temp, '# ')
        _temp = sentence_truncation(_temp, head_limit, tail_limit)
        # ... then restore the sentence's original heading depth ...
        _temp = replace_jinhao(_temp, f"{'#'*count_jinhao(temp)} ")
        # ... and finally apply the caller-requested replacement (if any).
        _temp = replace_jinhao(_temp, replacement)
        res[idx] = _temp+'\n'
    return res
# Instruction prompt prepended to every model input.  The model receives
# numbered sentences ("{line number} @ {sentence content}") and must answer
# with one "{line number}, {segment level}, {be a title?}" record per
# detected segment start (parsed by parse_answer_chunking_point).
PROMPT = ('You are an assistant good at reading and formatting documents, and you are also skilled at distinguishing '
          'the semantic and logical relationships of sentences between document context. The following is a text that '
          'has already been divided into sentences. Each line is formatted as: "{line number} @ {sentence content}". '
          'You need to segment this text based on semantics and format. There are multiple levels of granularity for '
          'segmentation, the higher level number means the finer granularity of the segmentation. Please ensure that '
          'each Level One segment is semantically complete after segmentation. A Level One segment may contain '
          'multiple Level Two segments, and so on. Please incrementally output the starting line numbers of each level '
          'of segments, and determine the level of the segment, as well as whether the content of the sentence at the '
          'starting line number can be used as the title of the segment. Finally, output a list format result, '
          'where each element is in the format of: "{line number}, {segment level}, {be a title?}".'
          '\n\n>>> Input text:\n')
def index_format(idx, line):
    """Render one numbered input sentence as "<idx> @ <line>"."""
    return '{} @ {}'.format(idx, line)
def points2clip(points, start_idx, end_idx):
    """Turn ordered cut points into consecutive [begin, end) clips.

    Given points [a, b, c, d] with start_idx <= a < b < c < d < end_idx the
    result is [[start_idx, a], [a, b], [b, c], [c, d], [d, end_idx]].  Points
    equal to start_idx or >= end_idx are skipped.
    """
    clips = []
    left = start_idx
    for point in points:
        if point != start_idx and point < end_idx:
            clips.append([left, point])
            left = point
    clips.append([left, end_idx])
    return clips
# Parse the raw model answer into per-level chunking points.
def parse_answer_chunking_point(answer_string, max_level):
    """Parse an answer of "{line}, {level name}, {title?}" rows into one list
    of starting line numbers per segmentation level.

    Within each level, a point is kept only if it is strictly greater than
    the point immediately before it in the raw sequence.  A malformed row
    raises (the caller treats that as a parse error).
    """
    by_level = {level_dict_en[i]: [] for i in range(max_level)}
    for row in answer_string.split('\n'):
        point, level, _ = row.split(', ')
        if level in by_level:
            by_level[level].append(int(point))
    parsed = list(by_level.values())
    for idx, points in enumerate(parsed):
        if not points:
            continue
        kept = [points[0]]
        for i in range(1, len(points)):
            # Compare against the raw predecessor, not the last kept point.
            if points[i] > points[i - 1]:
                kept.append(points[i])
        parsed[idx] = kept
    return parsed
# Maps a 0-based level index to the English level name the model uses in its
# answers (keys of the per-level buckets in parse_answer_chunking_point).
level_dict_en = {
    0: 'Level One',
    1: 'Level Two',
    2: 'Level Three',
    3: 'Level Four',
    4: 'Level Five',
    5: 'Level Six',
    6: 'Level Seven',
    7: 'Level Eight',
    8: 'Level Nine',
    9: 'Level Ten',
}
def check_answer_point(first_level_points, start_idx, end_idx):
    """Validate parsed level-one points.

    The first point must be >= start_idx, and the sequence must be strictly
    increasing with every subsequent point <= end_idx.  An empty list is valid.
    """
    print('parsed_answer:', first_level_points, start_idx, end_idx)
    if first_level_points and first_level_points[0] < start_idx:
        return False
    for prev, cur in zip(first_level_points, first_level_points[1:]):
        if cur <= prev or cur > end_idx:
            return False
    return True
def build_residual_lines(lines, global_chunk_points, start_idx, window_size, recurrent_type):
    """Build the "residual" context lines carried into the next inference call.

    Only active for recurrent_type == 2: takes the level-2 clips of the
    still-open level-1 segment (from its last level-1 point up to start_idx)
    and keeps a bounded sample of them so the next window starts with some
    context.  Returns [] for recurrent_type 0/1 or when even the smallest
    sample would exceed half the window.
    """
    if recurrent_type in [0, 1]:
        return []
    assert recurrent_type == 2, f'Not implemented for recurrent_type: {recurrent_type}'
    last_first_point = 0
    if len(global_chunk_points[0]) > 0:
        last_first_point = global_chunk_points[0][-1]
    # Level-2 points inside the current (still open) level-1 segment.
    current_second_points = filter(lambda p: p >= last_first_point, global_chunk_points[1])
    temp_second_clips = points2clip(current_second_points, last_first_point, start_idx)
    # Within one level-1 segment keep at most 5 level-2 clips (first 2 + last 3),
    # each clip capped at 20 lines.
    pre_seg_num, post_seg_num, line_num = 2, 3, 20
    while True:
        residual_second_clips = temp_second_clips
        if len(temp_second_clips) > (pre_seg_num + post_seg_num):
            residual_second_clips = (
                temp_second_clips[:pre_seg_num] + temp_second_clips[len(temp_second_clips)-post_seg_num:]
            )
        residual_lines = []
        for rsc in residual_second_clips:
            # Keep at most `line_num` lines per level-2 clip.
            pre_sent_idx, post_sent_idx = rsc[0], min(rsc[1], rsc[0]+line_num)
            residual_lines.extend(lines[pre_sent_idx: post_sent_idx])
        # NOTE(review): compares a character count against half of a
        # token-denominated window_size — cheap proxy, verify it is intended.
        if len('\n'.join(residual_lines)) < window_size/2:
            print(residual_lines)
            return residual_lines
        # Residual exceeds half the window: shrink the budgets
        # (head clips -1, tail clips -1, lines per clip -5) and retry.
        pre_seg_num, post_seg_num, line_num = pre_seg_num-1, post_seg_num-1, line_num-5
        # Still too large at the minimum setting: carry no residual input.
        if pre_seg_num * post_seg_num * line_num <= 0:
            return []
def union_chunk_points(local_chunk_points, global_chunk_points, max_idx):
    """Merge local points into the global per-level lists, keeping only points
    strictly below *max_idx*.  Mutates and returns *global_chunk_points*."""
    for idx in range(len(global_chunk_points)):
        kept = [p for p in local_chunk_points[idx] if p < max_idx]
        global_chunk_points[idx].extend(kept)
    return global_chunk_points
class HiChunkInferenceEngine:
    """Drive iterative hierarchical chunking of a document via an LLM.

    The document is split into sentences, fed to the model window by window,
    and the answers are parsed into multi-level chunking points that are
    finally rendered back as a '#'-headed chunked document.  Talks to an
    OpenAI-compatible server at OPENAI_BASE_URL (default http://localhost:8000).
    """
    def __init__(self, window_size, line_max_len, max_level, prompt):
        # window_size: token budget for one model call (input side).
        # line_max_len: per-sentence head-truncation budget (see pre_process).
        # max_level: number of segmentation levels tracked.
        # prompt: instruction text prepended to every model input.
        self.window_size = window_size
        self.line_max_len = line_max_len
        self.max_level = max_level
        self.prompt = prompt
        self.base_url = os.environ.get("OPENAI_BASE_URL", "http://localhost:8000")
        self.llm = openai.Client(base_url=f"{self.base_url}/v1", api_key="[empty]")
    def init_chunk_points(self):
        """Return one fresh, empty point list per segmentation level."""
        global_chunk_points = [[] for i in range(self.max_level)]
        return global_chunk_points
    def build_input_instruction(self, prompt, global_start_idx, sentences, window_size, residual_lines=None):
        """
        Build the input instruction for one inference call.

        Residual lines (context carried over from the previous window) are
        numbered first, then document sentences are appended until the token
        budget is reached.
        :param prompt: instruction prompt text.
        :param global_start_idx: global index of the first sentence to add.
        :param sentences: full list of document sentences.
        :param window_size: token budget for the assembled input.
        :param residual_lines: context lines prepended before the sentences.
        :return: (input_text, reached_end_of_document, sentences_consumed)
        """
        q = prompt
        # Concatenate residual lines first, if any.
        residual_index = 0
        while residual_lines is not None and residual_index < len(residual_lines):
            line_text = index_format(residual_index, residual_lines[residual_index])
            temp_text = q + line_text
            q = temp_text
            residual_index += 1
        assert self.count_length(q) <= window_size, 'residual lines exceeds window size'
        local_start_idx = 0
        cur_token_num = self.count_length(q)
        end = False
        # Concatenate sentences until the window budget would be exceeded.
        while global_start_idx < len(sentences):
            # Local numbering continues after the residual lines.
            line_text = index_format(local_start_idx + residual_index, sentences[global_start_idx])
            temp_text = q + line_text
            line_token_num = self.count_length(line_text)
            if cur_token_num + line_token_num > window_size:
                break
            cur_token_num += line_token_num
            q = temp_text
            local_start_idx += 1
            global_start_idx += 1
        if global_start_idx == len(sentences):
            end = True
        return q, end, local_start_idx
    def call_llm(self, input_text):
        """Send one chat completion request to the HiChunk model and return
        the answer text (deterministic: temperature 0, thinking disabled)."""
        response = self.llm.chat.completions.create(
            model='HiChunk',
            messages=[{'role': 'user', 'content': input_text}],
            temperature=0.0,
            max_tokens=4096,
            extra_body={
                "chat_template_kwargs": {"add_generation_prompt": True, "enable_thinking": False}
            }
        )
        return response.choices[0].message.content
    def count_length(self, text):
        """Return the token count of *text* via the server's /tokenize endpoint."""
        response = requests.post(
            url=f'{self.base_url}/tokenize',
            json={'model': 'HiChunk', 'prompt': text}
        ).json()
        return response['count']
    def pre_process(self, document):
        """Split *document* into sentence lists.

        Returns (input_lines, origin_lines): input_lines are head-truncated to
        self.line_max_len with the '#' prefix stripped (what the model sees);
        origin_lines are the untruncated originals used for reconstruction.
        """
        lines = map(lambda l: l.strip(), document.split('\n'))
        lines = list(filter(lambda l: len(l) != 0, lines))
        origin_lines = text2sentence(lines, None, -1, 0)  # original sentence lines, untruncated
        input_lines = text2sentence(lines, '', self.line_max_len, 0)  # truncated to self.line_max_len for model input
        return input_lines, origin_lines
    @staticmethod
    def post_process(origin_lines, global_chunk_points):
        """Assemble [chunk_text, level] pairs from the global chunking points.

        Points from all levels are flattened to (point, level) and sorted by
        position; each chunk spans from the previous point to the next and is
        labelled with the level of the point that opened it.
        """
        origin_lines_remove_jinhao = [replace_jinhao(l, '') for l in origin_lines]
        total_points = sorted(
            [[__, i + 1] for i, _ in enumerate(global_chunk_points) for __ in _],
            key=lambda p: p[0]
        )
        splits = []
        pre_level, pre_point = 1, 0
        for i, [p, level] in enumerate(total_points):
            if p == 0:
                continue
            splits.append([''.join(origin_lines_remove_jinhao[pre_point: p]), pre_level])
            pre_level = level
            pre_point = p
        # Tail chunk from the last point to the end of the document.
        splits.append([''.join(origin_lines_remove_jinhao[pre_point:]), pre_level])
        return splits
    def iterative_inf(self, lines, recurrent_type=1):
        """Run windowed inference over *lines* until the document is consumed.

        :param lines: numbered input sentences (from pre_process).
        :param recurrent_type: 0 = no overlap between windows; 1 = restart the
            next window from the last level-1 point; 2 = additionally carry
            residual context lines (see build_residual_lines).
        :return: dict with 'global_chunk_points', 'raw_qa' (per-call records)
            and 'error_count'.
        """
        error_count, start_idx = 0, 0
        raw_qa, residual_lines = [], []
        global_chunk_points = self.init_chunk_points()
        while start_idx < len(lines):
            residual_sent_num = len(residual_lines)
            question, is_end, question_sent_num = self.build_input_instruction(
                self.prompt, start_idx, lines, self.window_size, residual_lines
            )
            question_token_num = self.count_length(question)
            print('question len:', len(question), question_token_num)
            start_time = time.time()
            answer = self.call_llm(question)
            inf_time = time.time() - start_time
            answer_token_num = self.count_length(answer)
            print('answer:', answer)
            print('answer len:', answer_token_num)
            tmp = {
                'question': question, 'answer': answer, 'start_idx': start_idx, 'end_idx': start_idx+question_sent_num,
                'residual_sent_num': residual_sent_num, 'time': inf_time,
                'question_token_num': question_token_num, 'answer_token_num': answer_token_num,
            }
            # Parse the answer and convert local sentence indices to global ones.
            try:
                local_chunk_points = parse_answer_chunking_point(answer, self.max_level)
                if not check_answer_point(local_chunk_points[0], 0, question_sent_num+residual_sent_num-1):
                    print('###########check error##############')
                    tmp['status'] = 'check error'
                    # Invalid answer: fall back to a single level-1 point at start_idx.
                    local_chunk_points = self.init_chunk_points()
                    local_chunk_points[0].append(start_idx)
                    error_count += 1
                else:
                    tmp['status'] = 'check ok'
                    print('#############check ok################')
                    for idx, points in enumerate(local_chunk_points):
                        # Drop points that fall inside the residual prefix.
                        filter_points = filter(lambda p: p >= residual_sent_num, points)
                        # p - residual_sent_num + start_idx maps a local inference
                        # index to a global document sentence index.
                        local_chunk_points[idx] = [p - residual_sent_num + start_idx for p in filter_points]
            # NOTE(review): bare except — swallows all exceptions (including
            # KeyboardInterrupt); presumably meant to catch parse failures only.
            except:
                print('##########parsed error################')
                tmp['status'] = 'parse error'
                local_chunk_points = self.init_chunk_points()
                local_chunk_points[0].append(start_idx)
                error_count += 1
            raw_qa.append(tmp)
            print('local_chunk_points:', local_chunk_points)
            if is_end:  # whole document consumed: merge and stop
                start_idx += question_sent_num
                global_chunk_points = union_chunk_points(local_chunk_points, global_chunk_points, start_idx)
                break
            if len(local_chunk_points[0]) > 1 and recurrent_type in [1, 2]:
                # Several level-1 segments: discard the (possibly truncated) last
                # one and start the next window from its starting point.
                start_idx = local_chunk_points[0][-1]
                global_chunk_points = union_chunk_points(local_chunk_points, global_chunk_points, start_idx)
                residual_lines = []
            else:
                # Only one level-1 segment in this window: continue from the last
                # consumed line and carry residual context from this segment.
                start_idx += question_sent_num
                global_chunk_points = union_chunk_points(local_chunk_points, global_chunk_points, start_idx)
                residual_lines = build_residual_lines(
                    lines, global_chunk_points, start_idx, self.window_size, recurrent_type
                )
        print('global_chunk_points:', global_chunk_points)
        result = {
            'global_chunk_points': global_chunk_points,
            'raw_qa': raw_qa,
            'error_count': error_count,
        }
        return result
    def inference(self, document, recurrent_type=1):
        """Chunk *document* end to end.

        :return: (chunked_document, chunks) where chunks is a list of
            [chunk_text, level] pairs and chunked_document joins them with a
            '#'*level heading prefix per chunk.
        """
        input_lines, origin_lines = self.pre_process(document)
        chunked_result = self.iterative_inf(input_lines, recurrent_type=recurrent_type)
        chunks = self.post_process(origin_lines, chunked_result['global_chunk_points'])
        chunked_document = '\n'.join(['#'*c[1] + ' ' + c[0] for c in chunks])
        return chunked_document, chunks