# rm_code / sys_3round.py
# hahayang012's picture
# Upload folder using huggingface_hub
# d8a76be verified
# pip install datasets pyarrow regex
import re
import random
from datasets import load_dataset
# ========= Regular expressions =========
# Leading system turn: anchored at the start of the string, spans from the
# <|im_start|>system tag through the first <|im_end|> and any whitespace after it.
SYS_HEAD = re.compile(r"^<\|im_start\|>system\s.*?<\|im_end\|>\s*", re.S)
# One closed user/assistant turn. Group 1 is the whole turn text (tags included),
# group 2 is the role name ("user" or "assistant").
TURN_WITH_ROLE = re.compile(r"(<\|im_start\|>(user|assistant)\s*.*?<\|im_end\|>)", re.S)
# Speaker name followed by a colon (CJK/Latin letters, digits, spaces, underscores),
# e.g. "Kerensa:" / "小明:". NOTE(review): not referenced in the visible part of this file.
NAME_COLON = re.compile(r"^[\w\u4e00-\u9fa5][\w\u4e00-\u9fa5 _]{0,40}:\s*$")
in_path = "/home/data/raw/test/4201_2355_full_label_1000-8192.parquet" # input parquet file
out_path = "/home/data/raw/test/4201_2355_full_label_1000-8192_sys3round.parquet" # output parquet file (renamed to tell it apart from the input)
def join_with_clean_gap(system_block: str, body: str) -> str:
    """Concatenate the system block and the dialogue body.

    The system block is passed through untouched. A single newline is
    inserted between the two parts only when the system block does not
    already end with one. An empty system block yields ``body`` alone.
    """
    if system_block:
        separator = "" if system_block.endswith("\n") else "\n"
        return system_block + separator + body
    return body
def last_4rounds_user_to_open_assistant(chatml: str, rounds: int = 4) -> str:
    """Trim a ChatML prompt to its system block plus the most recent rounds.

    The result is laid out as ``system`` (verbatim, if present), then up to
    ``2 * rounds - 1`` closed user/assistant turns, then the last assistant
    turn left open (everything from its first ``<|im_end|>`` onward removed).
    With the default ``rounds=4`` the history is U, A, U, A, U, A, U.

    Args:
        chatml: The prompt text. Non-string values are returned unchanged.
        rounds: Number of rounds to keep, counting the open assistant turn as
            the end of the last round. Must be at least 1.

    Returns:
        The trimmed prompt. Text without any user/assistant tag is returned
        as-is behind the system block; text with user tags but no assistant
        tag is returned stripped behind the system block.

    Raises:
        ValueError: If ``rounds`` is smaller than 1.
    """
    if not isinstance(chatml, str):
        return chatml
    if rounds < 1:
        raise ValueError("rounds must be >= 1")

    # Split off the leading system block; it is kept verbatim.
    m_sys = SYS_HEAD.match(chatml)
    system_block = m_sys.group(0) if m_sys else ""
    # SYS_HEAD is anchored at the string start, so slicing at the match end
    # is equivalent to substituting the match away.
    text = chatml[m_sys.end():] if m_sys else chatml

    # Not ChatML at all: return conservatively, system block included.
    if ("<|im_start|>user" not in text) and ("<|im_start|>assistant" not in text):
        return join_with_clean_gap(system_block, text)

    # The last assistant tag is where the open-ended turn starts.
    last_ast = text.rfind("<|im_start|>assistant")
    if last_ast == -1:
        return join_with_clean_gap(system_block, text.strip())

    # Open the final assistant turn: drop its first <|im_end|> and all that follows.
    final_assistant_open = text[last_ast:].partition("<|im_end|>")[0]

    # Closed turns that precede the open assistant turn, as (role, text) pairs.
    turns = [(m.group(2), m.group(1)) for m in TURN_WITH_ROLE.finditer(text[:last_ast])]

    history_len = 2 * rounds - 1
    if len(turns) < history_len:
        # Not enough history for the full window: keep every closed turn.
        selected = [turn_text for _, turn_text in turns]
    else:
        # Window of history_len turns ending at the most recent user turn;
        # closed turns after that user turn are dropped.
        j = next(
            (i for i in range(len(turns) - 1, -1, -1) if turns[i][0] == "user"),
            None,
        )
        if j is None:
            # Fallback when there is no user turn at all.
            selected = [turn_text for _, turn_text in turns[-history_len:]]
        else:
            start = max(0, j - (history_len - 1))
            selected = [turn_text for _, turn_text in turns[start:j + 1]]

    prefix = ("\n".join(selected) + "\n") if selected else ""
    return join_with_clean_gap(system_block, prefix + final_assistant_open)
def ensure_linebreak_after_assistant(chosen_prompt: str) -> str:
    """Normalise line breaks around assistant tags in a ChatML prompt.

    Two rules are applied, in order, to every assistant turn:

    1. Each ``<|im_start|>assistant`` tag is followed by a line break; one is
       inserted when the tag is not already followed by optional whitespace
       and a newline.
    2. When the first line of an assistant turn is a short label ending in a
       colon (1 to 60 non-newline characters before the colon) and is
       followed by a line break, that line break and surrounding whitespace
       are collapsed to a single space, so label and content share a line.

    Args:
        chosen_prompt: The prompt text. Non-string values are returned
            unchanged.

    Returns:
        The prompt with both rules applied.
    """
    if not isinstance(chosen_prompt, str):
        return chosen_prompt

    # Rule 1: force a newline right after the assistant tag when missing.
    chosen_prompt = re.sub(
        r"(<\|im_start\|>assistant)(?!\s*\n)",
        r"\1\n",
        chosen_prompt,
    )

    # Rule 2: substitute over ALL assistant turns. A search + literal replace
    # would only fix turns textually identical to the first match and leave
    # turns with another label or different whitespace untouched.
    return re.sub(
        r"(<\|im_start\|>assistant\s*\n)([^\n]{1,60}:)(?:\s*\r?\n\s*)",
        r"\1\2 ",
        chosen_prompt,
    )
def _map_fn(ex):
    """Rewrite ``ex["chosen_prompt"]`` in place and return the example.

    The prompt is first trimmed by ``last_4rounds_user_to_open_assistant``
    and then passed through ``ensure_linebreak_after_assistant``.
    """
    trimmed = last_4rounds_user_to_open_assistant(ex["chosen_prompt"])
    ex["chosen_prompt"] = ensure_linebreak_after_assistant(trimmed)
    return ex
# ============ Batch processing + save + sample printing ============
ds = load_dataset("parquet", data_files=in_path, split="train")

# Optional: keep only three columns.
# keep_cols = ["chosen_prompt", "chosen", "reject"]
# drop_cols = [c for c in ds.column_names if c not in keep_cols]
# if drop_cols:
#     ds = ds.remove_columns(drop_cols)

# num_proc=4~8 can be passed to speed this up (watch memory usage).
ds = ds.map(
    _map_fn,
    desc="Keep system + last 4 rounds (open assistant) + linebreak rules",
)

# Write the processed rows to a new parquet file.
ds.to_parquet(out_path)
print(f"✅ Saved -> {out_path}")

# Print a random sample of up to 5 rows for a quick visual check.
total_rows = len(ds)
sample_indices = random.sample(range(total_rows), min(5, total_rows))
for sample_no, row in enumerate(ds.select(sample_indices), start=1):
    print(f"===== Sample {sample_no} / chosen_prompt 原样 =====")
    print(row["chosen_prompt"])
    print(f"===== Sample {sample_no} / chosen_prompt + chosen =====")
    print(row["chosen_prompt"] + row["chosen"])
    print(f"===== Sample {sample_no} / chosen_prompt + reject =====")
    print(row["chosen_prompt"] + row["reject"])
    print()