# pip install datasets pyarrow regex
"""Trim ChatML prompts to the system segment plus the most recent rounds.

Each ``chosen_prompt`` is reduced to: the leading system segment (verbatim),
up to five closed turns ending on a user turn, and a final *open* assistant
turn (no ``<|im_end|>``), so that a completion can be appended directly.
Line-break rules are then normalised after every assistant tag.

Run as a script to process the parquet file at ``in_path`` into ``out_path``.
"""
import random
import re

# ========= Regexes =========
# Leading system segment, if present (only matches at the very start).
SYS_HEAD = re.compile(r"^<\|im_start\|>system\s.*?<\|im_end\|>\s*", re.S)
# A closed user/assistant turn: group 1 is the whole turn, group 2 the role.
TURN_WITH_ROLE = re.compile(r"(<\|im_start\|>(user|assistant)\s*.*?<\|im_end\|>)", re.S)
# An assistant tag that is not followed by (optional whitespace and) a newline.
_ASSISTANT_WITHOUT_NEWLINE = re.compile(r"(<\|im_start\|>assistant)(?!\s*\n)")
# A "Name:" label alone on the first line after an assistant tag.
_SPEAKER_ON_OWN_LINE = re.compile(
    r"(<\|im_start\|>assistant\s*\n)([^\n]{1,60}:)(\s*\r?\n\s*)"
)

# Number of closed turns kept before the open assistant turn: U, A, U, A, U.
_CLOSED_TURNS_KEPT = 5

# ============ Batch job configuration ============
in_path = "/home/data/train_v3full.parquet"  # input
# NOTE(review): the file name says "2round" while the code keeps three rounds;
# left unchanged because downstream consumers may depend on this path.
out_path = "/home/data/train_sys_2round.parquet"  # output


def join_with_clean_gap(system_block: str, body: str) -> str:
    """Concatenate the system segment and the dialogue body.

    The system segment is kept verbatim. The only guarantee added is that it
    is separated from ``body`` by at least one newline: if ``system_block``
    already ends with a newline nothing is inserted, otherwise a single
    newline is appended.

    Args:
        system_block: The system segment, or an empty string if there is none.
        body: The dialogue text that follows the system segment.

    Returns:
        ``body`` alone when ``system_block`` is empty, otherwise the two parts
        joined with a newline between them.
    """
    if not system_block:
        return body
    if system_block.endswith("\n"):
        return system_block + body
    return system_block + "\n" + body


def last_3rounds_user_to_open_assistant(chatml: str) -> str:
    """Keep the system segment and the last three rounds of a ChatML prompt.

    The result has the shape
    ``user -> assistant -> user -> assistant -> user -> assistant (open)``,
    where the final assistant turn has its ``<|im_end|>`` (and anything after
    it) removed. Conversations with fewer than five closed turns before the
    last assistant tag keep all of them.

    Args:
        chatml: The ChatML-formatted prompt. Non-string values are returned
            unchanged.

    Returns:
        The trimmed prompt, or the input itself when it is not a string.
    """
    if not isinstance(chatml, str):
        return chatml

    # Split off the leading system segment (kept verbatim).
    m_sys = SYS_HEAD.match(chatml)
    system_block = m_sys.group(0) if m_sys else ""
    text = chatml[m_sys.end():] if m_sys else chatml

    # Not ChatML at all: return conservatively, system segment included.
    if "<|im_start|>user" not in text and "<|im_start|>assistant" not in text:
        return join_with_clean_gap(system_block, text)

    # The last assistant tag is where the open turn starts.
    last_ast = text.rfind("<|im_start|>assistant")
    if last_ast == -1:
        return join_with_clean_gap(system_block, text.strip())

    # Open assistant turn: drop its <|im_end|> and everything after it.
    final_assistant_open = text[last_ast:].partition("<|im_end|>")[0]

    # Closed turns that precede the open assistant turn, as (role, raw_text).
    turns = [
        (m.group(2), m.group(1)) for m in TURN_WITH_ROLE.finditer(text[:last_ast])
    ]

    if len(turns) >= _CLOSED_TURNS_KEPT:
        # Window of the most recent closed turns that ends on a user turn.
        last_user = next(
            (i for i in range(len(turns) - 1, -1, -1) if turns[i][0] == "user"),
            None,
        )
        if last_user is None:
            turns = turns[-_CLOSED_TURNS_KEPT:]
        else:
            start = max(0, last_user - (_CLOSED_TURNS_KEPT - 1))
            turns = turns[start:last_user + 1]

    selected = [raw for _, raw in turns]
    prefix = "\n".join(selected) + "\n" if selected else ""
    return join_with_clean_gap(system_block, prefix + final_assistant_open)


def ensure_linebreak_after_assistant(chosen_prompt: str) -> str:
    """Normalise line breaks after every ``<|im_start|>assistant`` tag.

    Two rules are applied, to all assistant turns in the prompt:

    * the assistant tag is always followed by a newline;
    * when the first line after the tag is a ``Name:`` label on its own line,
      the label is joined with the following content, separated by a single
      space.

    Args:
        chosen_prompt: The prompt text. Non-string values are returned
            unchanged.

    Returns:
        The normalised prompt, or the input itself when it is not a string.
    """
    if not isinstance(chosen_prompt, str):
        return chosen_prompt

    with_newlines = _ASSISTANT_WITHOUT_NEWLINE.sub(r"\1\n", chosen_prompt)
    # Substitute every match: fixing only the first one would leave later
    # turns (typically the final open one) with the label on its own line.
    return _SPEAKER_ON_OWN_LINE.sub(r"\1\2 ", with_newlines)


def _map_fn(ex):
    """Apply the trimming and line-break rules to one dataset row."""
    cp = last_3rounds_user_to_open_assistant(ex["chosen_prompt"])
    ex["chosen_prompt"] = ensure_linebreak_after_assistant(cp)
    return ex


def main() -> None:
    """Process ``in_path`` into ``out_path`` and print a few random samples."""
    # Imported here so the text helpers above stay importable without the
    # third-party ``datasets`` package installed.
    from datasets import load_dataset

    ds = load_dataset("parquet", data_files=in_path, split="train")

    # Keep only these three columns.
    keep_cols = ["chosen_prompt", "chosen", "reject"]
    drop_cols = [c for c in ds.column_names if c not in keep_cols]
    if drop_cols:
        ds = ds.remove_columns(drop_cols)

    # num_proc=4~8 can be passed to speed this up (mind the memory usage).
    ds = ds.map(_map_fn, desc="Keep system + last 3 rounds (open assistant) + linebreak rules")

    # Save.
    ds.to_parquet(out_path)
    print(f"✅ Saved -> {out_path}")

    # Print a small random sample for a visual check.
    idxs = random.sample(range(len(ds)), min(5, len(ds)))
    sampled = ds.select(idxs)
    for i, ex in enumerate(sampled, start=1):
        print(f"===== Sample {i} / chosen_prompt 原样 =====")
        print(ex["chosen_prompt"])
        print(f"===== Sample {i} / chosen_prompt + chosen =====")
        print(ex["chosen_prompt"] + ex["chosen"])
        print(f"===== Sample {i} / chosen_prompt + reject =====")
        print(ex["chosen_prompt"] + ex["reject"])
        print()


if __name__ == "__main__":
    main()