# File size: 5,038 Bytes (revision d8a76be)
# pip install datasets pyarrow regex
import re
import random
from datasets import load_dataset
# ========= Regular expressions =========
# Leading system turn of a ChatML string, together with any whitespace that
# follows its <|im_end|> marker. Anchored at the very start of the string.
SYS_HEAD = re.compile(
    r"^<\|im_start\|>system\s.*?<\|im_end\|>\s*",
    re.DOTALL,
)
# One closed user/assistant turn: group 1 is the whole turn text, group 2 is
# the role name ("user" or "assistant").
TURN_WITH_ROLE = re.compile(
    r"(<\|im_start\|>(user|assistant)\s*.*?<\|im_end\|>)",
    re.DOTALL,
)
# A line made up only of a speaker name followed by a colon (word characters,
# CJK characters, spaces, underscores), e.g. "Kerensa:" or "小明:".
# NOTE(review): not referenced by the functions visible in this file.
NAME_COLON = re.compile(r"^[\w\u4e00-\u9fa5][\w\u4e00-\u9fa5 _]{0,40}:\s*$")

# Input parquet file.
in_path = "/home/data/raw/test/4201_2355_full_label_1000-8192.parquet"
# Output parquet file (renamed so it can be told apart from the input).
# NOTE(review): the name says "3round" while the trimming function keeps
# 4 rounds -- confirm which is intended.
out_path = "/home/data/raw/test/4201_2355_full_label_1000-8192_sys3round.parquet"
def join_with_clean_gap(system_block: str, body: str) -> str:
    """
    Concatenate the system block and the conversation body.

    The system block is kept exactly as given. If it does not already end
    with a newline, a single "\\n" is inserted so that the body always
    starts on a new line.

    Args:
        system_block: Leading system segment; may be empty.
        body: The conversation text that follows the system segment.

    Returns:
        ``body`` unchanged when ``system_block`` is empty, otherwise the two
        parts joined with at least one newline between them.
    """
    if not system_block:
        return body
    # Only add a separator when the system block does not already supply one.
    separator = "" if system_block.endswith("\n") else "\n"
    return system_block + separator + body
def last_4rounds_user_to_open_assistant(chatml: str, rounds: int = 4) -> str:
    """
    Trim a ChatML prompt to its system block plus the most recent rounds.

    The result has the shape
    ``system, user, assistant, ..., user, assistant(open)``: the leading
    system segment is kept verbatim, followed by up to ``2 * rounds - 1``
    closed turns that end on the last user turn, followed by the final
    assistant turn with its ``<|im_end|>`` (and anything after it) removed so
    that it is left open for generation.

    Args:
        chatml: The ChatML prompt. Non-string values are returned untouched.
        rounds: Number of user/assistant rounds to keep, counting the final
            open assistant turn as part of the last round. The default of 4
            keeps 7 closed turns (U, A, U, A, U, A, U) before the open one.

    Returns:
        The trimmed prompt. When there are fewer closed turns than requested,
        all of them are kept (best effort). Text without any user/assistant
        marker is returned with only the system block re-attached.

    Raises:
        ValueError: If ``rounds`` is smaller than 1.
    """
    if rounds < 1:
        raise ValueError(f"rounds must be >= 1, got {rounds}")
    if not isinstance(chatml, str):
        return chatml

    user_tag = "<|im_start|>user"
    assistant_tag = "<|im_start|>assistant"

    # Split off the system block with a single match. SYS_HEAD is anchored at
    # the start of the string, so slicing at the match end is equivalent to
    # substituting the match away.
    m_sys = SYS_HEAD.match(chatml)
    system_block = m_sys.group(0) if m_sys else ""
    text = chatml[m_sys.end():] if m_sys else chatml

    # Not ChatML at all: hand the text back conservatively (system included).
    if user_tag not in text and assistant_tag not in text:
        return join_with_clean_gap(system_block, text)

    # The last assistant marker is where the open-ended turn starts.
    last_ast = text.rfind(assistant_tag)
    if last_ast == -1:
        return join_with_clean_gap(system_block, text.strip())

    # Open assistant turn: drop the first <|im_end|> after it and everything
    # that follows.
    final_assistant_open = text[last_ast:].partition("<|im_end|>")[0]

    # Closed turns that precede the open assistant turn, as (role, text).
    head = text[:last_ast]
    turns = [(m.group(2), m.group(1)) for m in TURN_WITH_ROLE.finditer(head)]

    # N rounds need U, A, ..., U = 2N - 1 closed turns of history.
    needed = 2 * rounds - 1
    if len(turns) < needed:
        # Not enough history: keep whatever is there.
        window = turns
    else:
        # Index of the most recent user turn; the window must end on it.
        j = next(
            (i for i in range(len(turns) - 1, -1, -1) if turns[i][0] == "user"),
            None,
        )
        if j is None:
            window = turns[-needed:]  # fallback: no user turn found
        else:
            window = turns[max(0, j - (needed - 1)):j + 1]

    # Every kept turn is followed by exactly one newline.
    prefix = "".join(turn_text + "\n" for _, turn_text in window)
    return join_with_clean_gap(system_block, prefix + final_assistant_open)
def ensure_linebreak_after_assistant(chosen_prompt: str) -> str:
    """
    Normalise the line layout right after every assistant marker.

    Two rules are applied, in order:

    1. ``<|im_start|>assistant`` must be followed by a line break; one is
       inserted when the marker is not followed by optional whitespace and a
       newline.
    2. When the first line of an assistant turn is a short label ending in a
       colon (at most 60 characters before the colon, e.g. ``Name:``) and the
       content starts on a later line, the whitespace/newlines after the
       colon are replaced by a single space so the label and the content
       share one line.

    Rule 2 is applied to every assistant turn in the prompt, not only to the
    first one that matches.

    Args:
        chosen_prompt: The prompt text. Non-string values are returned
            untouched.

    Returns:
        The prompt with both rules applied.
    """
    if not isinstance(chosen_prompt, str):
        return chosen_prompt

    # 1) Add a newline after any assistant marker that lacks one.
    chosen_prompt = re.sub(
        r"(<\|im_start\|>assistant)(?!\s*\n)",
        r"\1\n",
        chosen_prompt,
    )

    # 2) Pull the content up onto the "label:" line. A single re.sub covers
    #    every assistant turn, whatever label or whitespace each one has.
    return re.sub(
        r"(<\|im_start\|>assistant\s*\n)([^\n]{1,60}:)\s*\r?\n\s*",
        r"\1\2 ",
        chosen_prompt,
    )
def _map_fn(ex):
    """
    Rewrite the ``chosen_prompt`` field of a single example.

    The prompt is first trimmed by ``last_4rounds_user_to_open_assistant``
    and then passed through ``ensure_linebreak_after_assistant``. The example
    is updated in place and returned.
    """
    trimmed = last_4rounds_user_to_open_assistant(ex["chosen_prompt"])
    ex["chosen_prompt"] = ensure_linebreak_after_assistant(trimmed)
    return ex
# ============ Batch processing + save + sample printing ============
ds = load_dataset("parquet", data_files=in_path, split="train")

# # Keep only three columns
# keep_cols = ["chosen_prompt", "chosen", "reject"]
# drop_cols = [c for c in ds.column_names if c not in keep_cols]
# if drop_cols:
# ds = ds.remove_columns(drop_cols)

# num_proc=4~8 can be passed to speed this up (mind the memory usage).
ds = ds.map(
    _map_fn,
    desc="Keep system + last 4 rounds (open assistant) + linebreak rules",
)

# Write the result to a new parquet file.
ds.to_parquet(out_path)
print(f"✅ Saved -> {out_path}")

# Print up to 5 randomly chosen rows for a visual check.
sample_size = min(5, len(ds))
idxs = random.sample(range(len(ds)), sample_size)
sampled = ds.select(idxs)
for i, ex in enumerate(sampled):
    print(f"===== Sample {i+1} / chosen_prompt 原样 =====")
    prompt = ex["chosen_prompt"]
    print(prompt)
    print(f"===== Sample {i+1} / chosen_prompt + chosen =====")
    print(prompt + ex["chosen"])
    print(f"===== Sample {i+1} / chosen_prompt + reject =====")
    print(prompt + ex["reject"])
    print()
|