# pip install datasets pyarrow regex import re import random from datasets import load_dataset # ========= 正则 ========= SYS_HEAD = re.compile(r"^<\|im_start\|>system\s.*?<\|im_end\|>\s*", re.S) TURN_WITH_ROLE = re.compile(r"(<\|im_start\|>(user|assistant)\s*.*?<\|im_end\|>)", re.S) # 人名+冒号(中英数字空格下划线),如:Kerensa: / 小明: NAME_COLON = re.compile(r"^[\w\u4e00-\u9fa5][\w\u4e00-\u9fa5 _]{0,40}:\s*$") def last_3rounds_user_to_open_assistant(chatml: str) -> str: """ 取最近三轮:user → assistant → user → assistant → user → assistant(开放式) 去掉最前面的 system 段。 """ if not isinstance(chatml, str): return chatml text = SYS_HEAD.sub("", chatml) # 非 ChatML 就保守返回 if ("<|im_start|>user" not in text) and ("<|im_start|>assistant" not in text): return text # 找到最后一次 assistant(开放式起点) last_ast = text.rfind("<|im_start|>assistant") if last_ast == -1: return text.strip() # 开放式 assistant:去掉它后面的 <|im_end|> 及其后续 final_assistant_open = text[last_ast:] final_assistant_open = re.sub(r"<\|im_end\|>.*$", "", final_assistant_open, flags=re.S) # 在开放式之前收集闭合轮次 head = text[:last_ast] turns = [(m.group(2), m.group(1)) for m in TURN_WITH_ROLE.finditer(head)] if len(turns) < 5: # 历史不足三轮:尽力返回 + 开放式 prefix = "\n".join(t[1] for t in turns) if prefix: prefix += "\n" return prefix + final_assistant_open # 取以 user 结尾的最近 5 段:U, A, U, A, U j = next((i for i in range(len(turns)-1, -1, -1) if turns[i][0] == "user"), None) if j is None: selected = [t[1] for t in turns[-5:]] # 兜底 else: i = max(0, j - 4) selected = [t[1] for t in turns[i:j+1]] prefix = ("\n".join(selected) + "\n") if selected else "" return prefix + final_assistant_open # ============ 批处理 + 抽样打印 ============ in_path = "/home/data/train_v3full.parquet" # 输入 out_path = "/home/data/train_2round.parquet" # 输出 ds = load_dataset("parquet", data_files=in_path, split="train") # 只保留三列 keep_cols = ["chosen_prompt", "chosen", "reject"] drop_cols = [c for c in ds.column_names if c not in keep_cols] if drop_cols: ds = ds.remove_columns(drop_cols) def ensure_linebreak_after_assistant(chosen_prompt: str) -> str: """ - <|im_start|>assistant 后必须换行 - 人名: 后面不换行 """ # 1) 如果 assistant 标签后不是换行,就加换行 chosen_prompt = re.sub( r"(<\|im_start\|>assistant)(?!\s*\n)", # 后面不是换行 r"\1\n", chosen_prompt ) # 2) 如果是人名: 后面有换行,就去掉换行(保证人名和内容在同一行) m = re.search(r"(<\|im_start\|>assistant\s*\n)([^\n]{1,60}:)(\s*\r?\n\s*)", chosen_prompt) if m: before = m.group(1) name_colon = m.group(2) chosen_prompt = chosen_prompt.replace( before + name_colon + m.group(3), before + name_colon + " " ) return chosen_prompt def _map_fn(ex): cp = last_3rounds_user_to_open_assistant(ex["chosen_prompt"]) cp = ensure_linebreak_after_assistant(cp) ex["chosen_prompt"] = cp return ex # 可用 num_proc=4~8 加速(注意内存) ds = ds.map(_map_fn, desc="Build last 3 rounds (open assistant) + linebreak rules") ds.to_parquet(out_path) print(f"✅ Saved -> {out_path}") # 抽样打印 5 条(原样 + 拼接效果,便于检查是否多空行/人名是否同一行) idxs = random.sample(range(len(ds)), min(5, len(ds))) sampled = ds.select(idxs) for i, ex in enumerate(sampled): print(f"===== Sample {i+1} / chosen_prompt 原样 =====") print(ex["chosen_prompt"]) print(f"===== Sample {i+1} / chosen_prompt + chosen =====") print(ex["chosen_prompt"] + ex["chosen"]) print(f"===== Sample {i+1} / chosen_prompt + reject =====") print(ex["chosen_prompt"] + ex["reject"]) print()