| | |
| | import re |
| | import random |
| | from datasets import load_dataset |
| |
|
| | |
# Leading system turn at the very start of a ChatML string, including any
# whitespace that follows its <|im_end|>. re.S lets "." span newlines; "^" is
# not multiline here, so the pattern can only match at position 0.
SYS_HEAD = re.compile(r"^<\|im_start\|>system\s.*?<\|im_end\|>\s*", re.S)
# One complete user/assistant turn. group(1) is the whole turn text
# (from <|im_start|> through <|im_end|>), group(2) is the role name.
TURN_WITH_ROLE = re.compile(r"(<\|im_start\|>(user|assistant)\s*.*?<\|im_end\|>)", re.S)

# A whole line made of a short speaker name (word characters, CJK ideographs,
# spaces, underscores; at most 41 characters) followed by a colon and optional
# trailing whitespace.
# NOTE(review): not referenced anywhere in the visible part of this file.
NAME_COLON = re.compile(r"^[\w\u4e00-\u9fa5][\w\u4e00-\u9fa5 _]{0,40}:\s*$")

# Input parquet file and the path the rewritten dataset is saved to.
in_path = "/home/data/raw/test/4201_2355_full_label_1000-8192.parquet"
out_path = "/home/data/raw/test/4201_2355_full_label_1000-8192_sys3round.parquet"
def join_with_clean_gap(system_block: str, body: str) -> str:
    """Concatenate the system block and the conversation body.

    The system block is never modified. If it does not already end with a
    newline, a single newline is inserted so the body starts on its own line.

    Args:
        system_block: Text of the leading system turn; may be empty.
        body: The remaining conversation text.

    Returns:
        ``body`` alone when ``system_block`` is empty, otherwise the two
        pieces joined with at least one newline between them.
    """
    if not system_block:
        return body
    separator = "" if system_block.endswith("\n") else "\n"
    return system_block + separator + body
| |
|
def last_4rounds_user_to_open_assistant(chatml: str) -> str:
    """Trim a ChatML prompt to its system block plus the most recent four rounds.

    The kept shape is: user -> assistant -> user -> assistant -> user ->
    assistant -> user -> assistant (left open). The leading system turn, if
    any, is preserved verbatim.

    Args:
        chatml: The full ChatML prompt. Non-string values are returned
            unchanged.

    Returns:
        The system block followed by up to seven complete turns and the final
        assistant turn, cut before its first ``<|im_end|>`` so it stays open.
        If the text has no assistant turn at all, the text after the system
        block is returned without truncation.
    """
    if not isinstance(chatml, str):
        return chatml

    # Split off the leading system turn. SYS_HEAD can only match at position
    # 0, so slicing at the match end equals removing the match.
    m_sys = SYS_HEAD.match(chatml)
    system_block = m_sys.group(0) if m_sys else ""
    text = chatml[m_sys.end():] if m_sys else chatml

    last_ast = text.rfind("<|im_start|>assistant")
    if last_ast == -1:
        # No assistant turn to leave open: return the text as-is when there
        # are no user turns either, otherwise with outer whitespace stripped.
        if "<|im_start|>user" not in text:
            return join_with_clean_gap(system_block, text)
        return join_with_clean_gap(system_block, text.strip())

    # Final assistant turn, kept open: everything from its <|im_start|> up to
    # (not including) the first <|im_end|> that follows.
    final_assistant_open = text[last_ast:].partition("<|im_end|>")[0]

    # Complete user/assistant turns that precede the final assistant turn.
    head = text[:last_ast]
    turns = [(m.group(2), m.group(1)) for m in TURN_WITH_ROLE.finditer(head)]

    # Three full rounds (6 turns) plus the user turn of the fourth round.
    context_turns = 7

    if len(turns) < context_turns:
        # Short history: keep every turn that was found.
        selected = [turn_text for _, turn_text in turns]
    else:
        # Anchor the window on the last user turn so the kept history ends
        # with a user message right before the open assistant turn.
        last_user = next(
            (i for i in range(len(turns) - 1, -1, -1) if turns[i][0] == "user"),
            None,
        )
        if last_user is None:
            window = turns[-context_turns:]
        else:
            start = max(0, last_user - (context_turns - 1))
            window = turns[start:last_user + 1]
        selected = [turn_text for _, turn_text in window]

    prefix = ("\n".join(selected) + "\n") if selected else ""
    return join_with_clean_gap(system_block, prefix + final_assistant_open)
| |
|
def ensure_linebreak_after_assistant(chosen_prompt: str) -> str:
    """Normalise line breaks at the start of assistant turns.

    Two rules are applied:

    - ``<|im_start|>assistant`` must be followed by a newline.
    - When the first line of an assistant turn is a short "name:" line (at
      most 60 characters before the colon), the content must continue on the
      same line: the newline after the colon is replaced by a single space.

    The second rule is applied to every assistant turn in the prompt, so
    turns with differing speaker names or whitespace are treated alike.

    Args:
        chosen_prompt: The ChatML prompt. Non-string values are returned
            unchanged.

    Returns:
        The prompt with both rules applied.
    """
    if not isinstance(chosen_prompt, str):
        return chosen_prompt

    # Rule 1: insert a newline after the assistant tag unless the tag is
    # already followed by optional whitespace and a newline.
    chosen_prompt = re.sub(
        r"(<\|im_start\|>assistant)(?!\s*\n)",
        r"\1\n",
        chosen_prompt,
    )

    # Rule 2: collapse "name:\n<content>" into "name: <content>". re.sub
    # rewrites every match rather than only the first one found.
    return re.sub(
        r"(<\|im_start\|>assistant\s*\n)([^\n]{1,60}:)(\s*\r?\n\s*)",
        r"\1\2 ",
        chosen_prompt,
    )
| |
|
def _map_fn(ex):
    """Rewrite the ``chosen_prompt`` field of one dataset row.

    The prompt is first trimmed by ``last_4rounds_user_to_open_assistant`` and
    then passed through ``ensure_linebreak_after_assistant``.

    Args:
        ex: A row mapping that contains a ``"chosen_prompt"`` key.

    Returns:
        The same ``ex`` object, with ``"chosen_prompt"`` replaced in place.
    """
    trimmed = last_4rounds_user_to_open_assistant(ex["chosen_prompt"])
    ex["chosen_prompt"] = ensure_linebreak_after_assistant(trimmed)
    return ex
| |
|
| | |
| |
|
# Load the input parquet file (requested as the "train" split).
ds = load_dataset("parquet", data_files=in_path, split="train")

# Rewrite chosen_prompt on every row.
ds = ds.map(_map_fn, desc="Keep system + last 4 rounds (open assistant) + linebreak rules")

# Persist the rewritten dataset.
ds.to_parquet(out_path)
print(f"✅ Saved -> {out_path}")

# Spot-check: print up to five randomly chosen rows, showing the prompt on
# its own and followed by each candidate completion.
idxs = random.sample(range(len(ds)), min(5, len(ds)))
sampled = ds.select(idxs)
for i, ex in enumerate(sampled, start=1):
    prompt = ex["chosen_prompt"]
    print(f"===== Sample {i} / chosen_prompt 原样 =====")
    print(prompt)
    # Values are formatted rather than concatenated with "+", so a row whose
    # chosen_prompt is not a string (which _map_fn passes through untouched)
    # is still printed instead of raising TypeError.
    print(f"===== Sample {i} / chosen_prompt + chosen =====")
    print(f"{prompt}{ex['chosen']}")
    print(f"===== Sample {i} / chosen_prompt + reject =====")
    print(f"{prompt}{ex['reject']}")
    print()
| | print() |
| |
|