# pip install datasets pyarrow
"""Trim ChatML prompts to the system block plus the last few dialogue rounds.

Reads a parquet file, rewrites every ``chosen_prompt`` so that it keeps the
leading system block verbatim and only the most recent rounds of dialogue,
ending in an open (unterminated) assistant turn. It then normalizes line
breaks after assistant tags, writes the result to a new parquet file, and
prints a few random samples.
"""
import random
import re

# ========= Regular expressions =========
# Leading system block (start of string only), including trailing whitespace.
SYS_HEAD = re.compile(r"^<\|im_start\|>system\s.*?<\|im_end\|>\s*", re.S)
# One closed user/assistant turn: group(1) is the whole turn, group(2) the role.
TURN_WITH_ROLE = re.compile(r"(<\|im_start\|>(user|assistant)\s*.*?<\|im_end\|>)", re.S)
# Speaker name + colon (CJK/Latin letters, digits, spaces, underscores),
# e.g. "Kerensa:" / "小明:". Not referenced by this script.
NAME_COLON = re.compile(r"^[\w\u4e00-\u9fa5][\w\u4e00-\u9fa5 _]{0,40}:\s*$")
# An assistant tag that is NOT followed by (optional whitespace and) a newline.
_ASSISTANT_TAG_NO_BREAK = re.compile(r"(<\|im_start\|>assistant)(?!\s*\n)")
# An assistant tag whose first content line is a short "Name:" line that is
# itself followed by a line break (plus any surrounding whitespace).
_ASSISTANT_NAME_THEN_BREAK = re.compile(
    r"(<\|im_start\|>assistant\s*\n)([^\n]{1,60}:)(\s*\r?\n\s*)"
)

in_path = "/home/data/raw/test/4201_2355_full_label_1000-8192.parquet"  # input
# NOTE(review): the output name says "3round" while the pipeline keeps
# N_ROUNDS (= 4) rounds -- confirm the intended file name.
out_path = "/home/data/raw/test/4201_2355_full_label_1000-8192_sys3round.parquet"  # output

# Number of dialogue rounds kept by the pipeline.
N_ROUNDS = 4


def join_with_clean_gap(system_block: str, body: str) -> str:
    """Prepend *system_block* (left untouched) to *body*.

    Guarantees at least one newline between the system block and the
    dialogue that follows. An empty system block returns *body* unchanged.
    """
    if not system_block:
        return body
    if system_block.endswith("\n"):
        return system_block + body
    return system_block + "\n" + body


def last_4rounds_user_to_open_assistant(chatml: str, *, n_rounds: int = 4) -> str:
    """Keep the system block plus the last *n_rounds* rounds of a ChatML prompt.

    With the default ``n_rounds=4`` the kept dialogue is
    user -> assistant -> user -> assistant -> user -> assistant -> user -> assistant,
    where the final assistant turn is left open: its ``<|im_end|>`` and
    everything after it are removed. The leading system block is kept verbatim.

    Args:
        chatml: Prompt text in ChatML format. Non-string values (e.g. ``None``)
            are returned unchanged.
        n_rounds: Rounds to keep, counting the open assistant turn as part of
            the last round. Must be >= 1.

    Returns:
        The trimmed prompt. If there are fewer closed turns than needed, all of
        them are kept. Text containing no user/assistant tag is returned as-is
        with the system block re-attached. Text with user tags but no
        assistant tag is returned stripped, with the system block re-attached.

    Raises:
        ValueError: If *n_rounds* < 1.
    """
    if n_rounds < 1:
        raise ValueError(f"n_rounds must be >= 1, got {n_rounds}")
    if not isinstance(chatml, str):
        return chatml

    # The system block (if any) is preserved byte-for-byte.
    m_sys = SYS_HEAD.match(chatml)
    system_block = m_sys.group(0) if m_sys else ""
    text = chatml[len(system_block):]

    # Not ChatML: return conservatively (with the system block).
    if "<|im_start|>user" not in text and "<|im_start|>assistant" not in text:
        return join_with_clean_gap(system_block, text)

    # The last assistant tag starts the open assistant turn.
    last_ast = text.rfind("<|im_start|>assistant")
    if last_ast == -1:
        return join_with_clean_gap(system_block, text.strip())

    # Open the final assistant turn: drop its <|im_end|> and everything after.
    final_assistant_open = text[last_ast:].partition("<|im_end|>")[0]

    # Closed turns before the open assistant turn, as (role, full_turn_text).
    turns = [(m.group(2), m.group(1)) for m in TURN_WITH_ROLE.finditer(text[:last_ast])]
    # U, A, ..., U segments needed before the open assistant (7 for 4 rounds).
    history_len = 2 * n_rounds - 1

    if len(turns) < history_len:
        # Not enough history: keep whatever is there.
        selected = turns
    else:
        # Most recent window of history_len turns that ends on a user turn.
        last_user = next(
            (i for i in range(len(turns) - 1, -1, -1) if turns[i][0] == "user"),
            None,
        )
        if last_user is None:
            selected = turns[-history_len:]  # fallback: no user turn at all
        else:
            selected = turns[max(0, last_user - history_len + 1): last_user + 1]

    prefix = "".join(turn_text + "\n" for _, turn_text in selected)
    return join_with_clean_gap(system_block, prefix + final_assistant_open)


def ensure_linebreak_after_assistant(chosen_prompt: str) -> str:
    """Normalize line breaks around assistant tags.

    Two rules are applied to every assistant turn:

    - ``<|im_start|>assistant`` must be followed by a newline; one is inserted
      if it is missing.
    - If the first content line is a short ``Name:`` line (at most 60
      characters before the colon) followed by a line break, the break and
      surrounding whitespace are replaced with a single space. The speaker
      name and the content then share one line.

    Non-string values are returned unchanged.
    """
    if not isinstance(chosen_prompt, str):
        return chosen_prompt
    chosen_prompt = _ASSISTANT_TAG_NO_BREAK.sub(r"\1\n", chosen_prompt)
    # Applied to every assistant turn, not just the first match, so that each
    # speaker (and each newline style) is handled the same way.
    return _ASSISTANT_NAME_THEN_BREAK.sub(r"\1\2 ", chosen_prompt)


def _map_fn(ex):
    """Rewrite ``ex["chosen_prompt"]`` in place and return *ex* (``Dataset.map`` callback)."""
    cp = last_4rounds_user_to_open_assistant(ex["chosen_prompt"], n_rounds=N_ROUNDS)
    ex["chosen_prompt"] = ensure_linebreak_after_assistant(cp)
    return ex


def main() -> None:
    """Load the input parquet, transform it, save it, and print 5 random samples."""
    # Imported here so the pure text helpers above can be used and tested
    # without the `datasets` package installed.
    from datasets import load_dataset

    # ============ Batch processing + save + sample printing ============
    ds = load_dataset("parquet", data_files=in_path, split="train")

    # # Keep only three columns
    # keep_cols = ["chosen_prompt", "chosen", "reject"]
    # drop_cols = [c for c in ds.column_names if c not in keep_cols]
    # if drop_cols:
    #     ds = ds.remove_columns(drop_cols)

    # Passing num_proc=4..8 speeds this up (watch memory usage).
    ds = ds.map(_map_fn, desc="Keep system + last 4 rounds (open assistant) + linebreak rules")

    # Save to a new parquet file.
    ds.to_parquet(out_path)
    print(f"✅ Saved -> {out_path}")

    # Print up to 5 random samples.
    idxs = random.sample(range(len(ds)), min(5, len(ds)))
    sampled = ds.select(idxs)
    for i, ex in enumerate(sampled):
        print(f"===== Sample {i+1} / chosen_prompt 原样 =====")
        print(ex["chosen_prompt"])
        print(f"===== Sample {i+1} / chosen_prompt + chosen =====")
        print(ex["chosen_prompt"] + ex["chosen"])
        print(f"===== Sample {i+1} / chosen_prompt + reject =====")
        print(ex["chosen_prompt"] + ex["reject"])
        print()


if __name__ == "__main__":
    main()