Spaces:
Starting on A100
Starting on A100
| """๋ฏผ์๋ต๋ณ ํ์ต ๋ฐ์ดํฐ ํ์ดํ๋ผ์ธ.""" | |
| from __future__ import annotations | |
| import hashlib | |
| import json | |
| import logging | |
| import random | |
| from pathlib import Path | |
| from typing import Any | |
| from .config import DataConfig | |
| from .parsers import AdminLawParser, GovQALocalParser, GovQAParser, GukripParser | |
| logger = logging.getLogger(__name__) | |
class CivilResponseDataPipeline:
    """Pipeline that converts raw AI Hub civil-complaint QA datasets
    (71852 and 71847) into instruction-tuning JSONL files.

    Steps performed by :meth:`run`: collect records from both datasets,
    deduplicate on question+answer, length-filter, deterministically
    split into train/val, and write ``train.jsonl`` / ``val.jsonl``.

    NOTE(review): the Korean string literals below are mojibake-encoded
    in this source; they are preserved byte-for-byte because they name
    on-disk directories and appear in emitted records/logs.
    """

    # Instruction prompt prepended to every emitted record.
    INSTRUCTION_TEXT = "๋ค์ ๋ฏผ์์ ๋ํ ๋ต๋ณ์ ์์ฑํด ์ฃผ์ธ์."

    def __init__(self, config: DataConfig | None = None):
        """Store the data configuration; build a default one when omitted."""
        self.config = config or DataConfig()

    def run(self) -> dict[str, int]:
        """Execute the full pipeline.

        Returns:
            Summary statistics: ``{"total": ..., "train": ..., "val": ...}``
            where ``total`` is the post-filter record count.
        """
        records: list[dict] = []

        logger.info("71852 ๋ฐ์ดํฐ ์ฒ๋ฆฌ ์์")
        records_71852 = self._process_71852()
        logger.info("71852 ๋ฐ์ดํฐ %d๊ฐ ์์ง", len(records_71852))
        records.extend(records_71852)

        logger.info("71847 ๋ฐ์ดํฐ ์ฒ๋ฆฌ ์์")
        records_71847 = self._process_71847()
        logger.info("71847 ๋ฐ์ดํฐ %d๊ฐ ์์ง", len(records_71847))
        records.extend(records_71847)

        logger.info("์ค๋ณต ์ ๊ฑฐ ์ ์ด %d๊ฐ", len(records))
        records = self._deduplicate(records)
        logger.info("์ค๋ณต ์ ๊ฑฐ ํ %d๊ฐ", len(records))

        records = self._filter(records)
        logger.info("ํํฐ๋ง ํ %d๊ฐ", len(records))

        train, val = self._split(records)
        logger.info("train=%d, val=%d", len(train), len(val))

        output_dir = Path(self.config.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        self._save_jsonl(train, output_dir / "train.jsonl")
        self._save_jsonl(val, output_dir / "val.jsonl")
        return {
            "total": len(records),
            "train": len(train),
            "val": len(val),
        }

    # ------------------------------------------------------------------
    # Per-dataset processing
    # ------------------------------------------------------------------
    def _process_71852(self) -> list[dict]:
        """Collect records from dataset 71852.

        Three subdirectory families, each with its own parser; both the
        ``train`` and ``val`` splits of each family are scanned.  The
        (parser, subdir) pairs preserve the original processing order.
        """
        base = Path(self.config.raw_dir) / "71852"
        records: list[dict] = []
        parser_dirs: tuple[tuple[Any, str], ...] = (
            (GukripParser(), "๊ตญ๋ฆฝ"),
            (GovQAParser(), "์ค์"),
            (GovQALocalParser(), "์ง๋ฐฉ"),
        )
        for parser, subdir in parser_dirs:
            for split in ("train", "val"):
                dir_path = base / split / subdir
                if dir_path.exists():
                    records.extend(self._parse_dir(parser, dir_path))
        return records

    def _process_71847(self) -> list[dict]:
        """Collect records from dataset 71847 (two labelled QA subsets)."""
        base = Path(self.config.raw_dir) / "71847"
        records: list[dict] = []
        sources = (
            ("71847_๊ฒฐ์ ๋ก", "TL_๊ฒฐ์ ๋ก_QA"),
            ("71847_๋ฒ๋ น", "TL_๋ฒ๋ น_QA"),
        )
        for label, dirname in sources:
            # Parser is constructed unconditionally, as in the original,
            # in case construction has side effects.
            parser = AdminLawParser(source_label=label)
            dir_path = base / dirname
            if dir_path.exists():
                records.extend(self._parse_dir(parser, dir_path))
        return records

    # ------------------------------------------------------------------
    # Utilities
    # ------------------------------------------------------------------
    def _parse_dir(self, parser: Any, dir_path: Path) -> list[dict]:
        """Parse every ``*.json`` file in *dir_path* with *parser*.

        Bug fix: the original signature was missing ``self`` even though
        every call site uses ``self._parse_dir(parser, dir_path)``, which
        would raise ``TypeError`` on the first call.

        Files whose parse raises are logged and skipped (best-effort).
        """
        records: list[dict] = []
        json_files = list(dir_path.glob("*.json"))
        logger.debug(" %s: %d ํ์ผ", dir_path, len(json_files))
        for filepath in json_files:
            try:
                records.extend(parser.parse(filepath))
            except Exception as exc:  # noqa: BLE001 - skip bad files, keep going
                logger.warning("ํ์ฑ ์คํจ %s: %s", filepath, exc)
        return records

    def _deduplicate(self, records: list[dict]) -> list[dict]:
        """Remove duplicates keyed on an MD5 hash of question+answer.

        First occurrence wins; input order is otherwise preserved.
        """
        seen: set[str] = set()
        unique: list[dict] = []
        for rec in records:
            # MD5 is used purely as a dedup fingerprint, not for security.
            key = hashlib.md5(  # nosec B324
                (rec["question"] + rec["answer"]).encode("utf-8"),
                usedforsecurity=False,
            ).hexdigest()
            if key not in seen:
                seen.add(key)
                unique.append(rec)
        return unique

    def _filter(self, records: list[dict]) -> list[dict]:
        """Keep records whose lengths satisfy the configured bounds.

        Answer length must lie in [min_answer_length, max_answer_length]
        (inclusive) and question length must be >= min_question_length.
        """
        min_a = self.config.min_answer_length
        max_a = self.config.max_answer_length
        min_q = self.config.min_question_length
        filtered: list[dict] = []
        for rec in records:
            if not (min_a <= len(rec["answer"]) <= max_a):
                continue
            if len(rec["question"]) < min_q:
                continue
            filtered.append(rec)
        return filtered

    def _split(self, records: list[dict]) -> tuple[list[dict], list[dict]]:
        """Deterministically shuffle and split by ``config.train_ratio``.

        Uses a local ``random.Random(42)`` instead of seeding the global
        ``random`` module: the shuffle sequence is identical (same seed,
        same Mersenne Twister), but callers' global random state is no
        longer clobbered as a side effect.
        """
        shuffled = list(records)
        random.Random(42).shuffle(shuffled)
        split_idx = int(len(shuffled) * self.config.train_ratio)
        return shuffled[:split_idx], shuffled[split_idx:]

    def _save_jsonl(self, records: list[dict], filepath: Path) -> None:
        """Write *records* to *filepath* in instruction-tuning JSONL form.

        Each line carries instruction/input/output plus source metadata;
        ``category`` defaults to an empty string when absent.
        """
        filepath = Path(filepath)
        with open(filepath, "w", encoding="utf-8") as f:
            for rec in records:
                line = {
                    "instruction": self.INSTRUCTION_TEXT,
                    "input": rec["question"],
                    "output": rec["answer"],
                    "source": rec["source"],
                    "category": rec.get("category", ""),
                }
                f.write(json.dumps(line, ensure_ascii=False) + "\n")
        logger.info("์ ์ฅ ์๋ฃ: %s (%d ๋ ์ฝ๋)", filepath, len(records))