File size: 3,427 Bytes
ab35c3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
数据预处理脚本。
步骤:
1. 读取原始平行语料(JSONL)。
2. 构建 / 保存词表(调用统一 Tokenizer 接口)。
3. 将句子转成 id 序列并写回 processed/*.jsonl。
"""
import argparse
import importlib
import json
import os
import pickle
from pathlib import Path
from typing import Any, Dict, List

import yaml
from tqdm.auto import tqdm

from tokenizer import BaseTokenizer, JiebaEnTokenizer


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        default="config.yaml",
        help="yaml 配置文件路径",
    )
    return parser.parse_args()


def load_config(path: str | Path) -> Dict[str, Any]:
    """Read the YAML config file at *path* and return its contents as a dict."""
    with open(path, "r", encoding="utf-8") as fh:
        cfg = yaml.safe_load(fh)
    return cfg


def load_raw_corpus(cfg: Dict[str, Any]) -> tuple[list[dict], list[dict], list[dict]]:
    """Load the raw train/val/test parallel corpora.

    Each file is JSONL (one JSON object per line); the paths come from
    ``cfg["data"]["raw_train"/"raw_val"/"raw_test"]``.

    Returns:
        A ``(train, val, test)`` triple of lists of per-line dicts.
    """

    def _read(path: str) -> list[dict]:
        records: list[dict] = []
        with open(path, "r", encoding="utf-8") as fh:
            for line in fh:
                records.append(json.loads(line))
        return records

    data_cfg = cfg["data"]
    return (
        _read(data_cfg["raw_train"]),
        _read(data_cfg["raw_val"]),
        _read(data_cfg["raw_test"]),
    )


def instantiate_tokenizer(cfg: Dict[str, Any]) -> BaseTokenizer:
    """Dynamically load and instantiate the tokenizer class.

    The dotted path comes from ``cfg["tokenizer"]``; when absent, the default
    ``tokenizer.JiebaEnTokenizer`` is used. Example config entry::

        tokenizer: my_pkg.my_tok.MyTokenizer
    """
    dotted_path = cfg.get("tokenizer", "tokenizer.JiebaEnTokenizer")
    module_path, class_name = dotted_path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    tokenizer_cls: type[BaseTokenizer] = getattr(module, class_name)
    return tokenizer_cls()


def make_processed_dirs(cfg: Dict[str, Any]) -> None:
    """Ensure the processed-output directory exists (parents created, idempotent)."""
    out_dir = Path(cfg["data"]["processed_dir"])
    out_dir.mkdir(parents=True, exist_ok=True)


def save_vocab(tokenizer: BaseTokenizer, cfg: Dict[str, Any]) -> None:
    """Pickle the source and target vocabularies to the configured paths."""
    data_cfg = cfg["data"]
    for vocab, path_key in (
        (tokenizer.src_vocab, "src_vocab"),
        (tokenizer.tgt_vocab, "tgt_vocab"),
    ):
        with open(data_cfg[path_key], "wb") as fh:
            pickle.dump(vocab, fh)


def encode_and_save(
    dataset: list[dict],
    out_path: str | Path,
    tokenizer: BaseTokenizer,
) -> None:
    """Encode a batch of samples and write them as JSONL.

    Each output line has the form::

        {"src_ids": [...], "tgt_ids": [...]}
    """
    with open(out_path, "w", encoding="utf-8") as fout:
        for sample in tqdm(dataset, desc=f"writing {out_path}"):
            record = {
                "src_ids": tokenizer.encode_src(sample["zh"]),
                "tgt_ids": tokenizer.encode_tgt(sample["en"]),
            }
            fout.write(json.dumps(record, ensure_ascii=False))
            fout.write("\n")


def main() -> None:
    """Entry point: build vocabularies from the training split, then encode all splits."""
    cfg = load_config(parse_args().config)

    make_processed_dirs(cfg)
    tokenizer = instantiate_tokenizer(cfg)

    train_raw, val_raw, test_raw = load_raw_corpus(cfg)

    # ---------- Build the vocabularies (training split only) ----------
    zh_sentences = [sample["zh"] for sample in train_raw]
    en_sentences = [sample["en"] for sample in train_raw]
    tokenizer.build_vocab(
        zh_sentences,
        en_sentences,
        min_freq=cfg["data"].get("min_freq", 2),
    )
    save_vocab(tokenizer, cfg)

    # ---------- Encode every split and write it out ----------
    splits = (
        (train_raw, cfg["data"]["train_processed"]),
        (val_raw, cfg["data"]["val_processed"]),
        (test_raw, cfg["data"]["test_processed"]),
    )
    for dataset, out_path in splits:
        encode_and_save(dataset, out_path, tokenizer)

    print("预处理完成!")


if __name__ == "__main__":
    main()