# Uploaded by umyunsang via huggingface_hub (commit 9e65b56, verified)
"""๋ฏผ์›๋‹ต๋ณ€ ํ•™์Šต ๋ฐ์ดํ„ฐ ํŒŒ์ดํ”„๋ผ์ธ."""
from __future__ import annotations
import hashlib
import json
import logging
import random
from pathlib import Path
from typing import Any
from .config import DataConfig
from .parsers import AdminLawParser, GovQALocalParser, GovQAParser, GukripParser
logger = logging.getLogger(__name__)
class CivilResponseDataPipeline:
    """Pipeline converting AI Hub raw data into instruction-tuning JSONL.

    Collects Q/A records from datasets 71852 and 71847 under
    ``config.raw_dir``, deduplicates and length-filters them, then writes
    deterministic train/val splits as JSONL into ``config.output_dir``.
    """

    # Instruction prepended to every record (runtime string — kept verbatim).
    INSTRUCTION_TEXT = "๋‹ค์Œ ๋ฏผ์›์— ๋Œ€ํ•œ ๋‹ต๋ณ€์„ ์ž‘์„ฑํ•ด ์ฃผ์„ธ์š”."

    def __init__(self, config: DataConfig | None = None):
        """Store the given config, falling back to a default ``DataConfig``."""
        self.config = config or DataConfig()

    def run(self) -> dict[str, int]:
        """Execute the full pipeline and return summary statistics.

        Returns:
            Mapping with ``total`` (records after filtering), ``train`` and
            ``val`` (split sizes).
        """
        records: list[dict] = []

        logger.info("71852 ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ ์‹œ์ž‘")
        records_71852 = self._process_71852()
        logger.info("71852 ๋ฐ์ดํ„ฐ %d๊ฐœ ์ˆ˜์ง‘", len(records_71852))
        records.extend(records_71852)

        logger.info("71847 ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ ์‹œ์ž‘")
        records_71847 = self._process_71847()
        logger.info("71847 ๋ฐ์ดํ„ฐ %d๊ฐœ ์ˆ˜์ง‘", len(records_71847))
        records.extend(records_71847)

        logger.info("์ค‘๋ณต ์ œ๊ฑฐ ์ „ ์ด %d๊ฐœ", len(records))
        records = self._deduplicate(records)
        logger.info("์ค‘๋ณต ์ œ๊ฑฐ ํ›„ %d๊ฐœ", len(records))

        records = self._filter(records)
        logger.info("ํ•„ํ„ฐ๋ง ํ›„ %d๊ฐœ", len(records))

        train, val = self._split(records)
        logger.info("train=%d, val=%d", len(train), len(val))

        output_dir = Path(self.config.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        self._save_jsonl(train, output_dir / "train.jsonl")
        self._save_jsonl(val, output_dir / "val.jsonl")

        return {
            "total": len(records),
            "train": len(train),
            "val": len(val),
        }

    # ------------------------------------------------------------------
    # Per-dataset processing
    # ------------------------------------------------------------------
    def _process_71852(self) -> list[dict]:
        """Collect records from dataset 71852 (train/val, three agency types)."""
        base = Path(self.config.raw_dir) / "71852"
        records: list[dict] = []
        # (parser, subdirectory) per agency type; subdirectory names are the
        # Korean labels used in the raw dump.
        sources = (
            (GukripParser(), "๊ตญ๋ฆฝ"),
            (GovQAParser(), "์ค‘์•™"),
            (GovQALocalParser(), "์ง€๋ฐฉ"),
        )
        for parser, subdir in sources:
            for split in ("train", "val"):
                dir_path = base / split / subdir
                if dir_path.exists():
                    records.extend(self._parse_dir(parser, dir_path))
        return records

    def _process_71847(self) -> list[dict]:
        """Collect records from dataset 71847 (decision-precedent and statute QA)."""
        base = Path(self.config.raw_dir) / "71847"
        records: list[dict] = []
        for source_label, subdir in (
            ("71847_๊ฒฐ์ •๋ก€", "TL_๊ฒฐ์ •๋ก€_QA"),
            ("71847_๋ฒ•๋ น", "TL_๋ฒ•๋ น_QA"),
        ):
            dir_path = base / subdir
            if dir_path.exists():
                parser = AdminLawParser(source_label=source_label)
                records.extend(self._parse_dir(parser, dir_path))
        return records

    # ------------------------------------------------------------------
    # Utilities
    # ------------------------------------------------------------------
    @staticmethod
    def _parse_dir(parser: Any, dir_path: Path) -> list[dict]:
        """Parse every ``*.json`` file in *dir_path*, skipping files that fail.

        Failures are logged and swallowed deliberately (best-effort ingest).
        """
        records: list[dict] = []
        json_files = list(dir_path.glob("*.json"))
        logger.debug(" %s: %d ํŒŒ์ผ", dir_path, len(json_files))
        for filepath in json_files:
            try:
                records.extend(parser.parse(filepath))
            except Exception as exc:  # noqa: BLE001
                logger.warning("ํŒŒ์‹ฑ ์‹คํŒจ %s: %s", filepath, exc)
        return records

    def _deduplicate(self, records: list[dict]) -> list[dict]:
        """Remove duplicates keyed on an MD5 of question + answer.

        First occurrence wins; input order is preserved.
        """
        seen: set[str] = set()
        unique: list[dict] = []
        for rec in records:
            # "\x1f" separator prevents boundary collisions such as
            # ("ab", "c") vs ("a", "bc") hashing to the same key.
            key = hashlib.md5(  # nosec B324
                (rec["question"] + "\x1f" + rec["answer"]).encode("utf-8"),
                usedforsecurity=False,
            ).hexdigest()
            if key not in seen:
                seen.add(key)
                unique.append(rec)
        return unique

    def _filter(self, records: list[dict]) -> list[dict]:
        """Keep records whose question/answer lengths fall inside config bounds."""
        cfg = self.config
        return [
            rec
            for rec in records
            if cfg.min_answer_length <= len(rec["answer"]) <= cfg.max_answer_length
            and len(rec["question"]) >= cfg.min_question_length
        ]

    def _split(self, records: list[dict]) -> tuple[list[dict], list[dict]]:
        """Shuffle deterministically and split by ``config.train_ratio``."""
        shuffled = list(records)
        # Local Random(42) yields the same shuffle as random.seed(42) but
        # does not clobber the process-wide RNG state.
        random.Random(42).shuffle(shuffled)
        split_idx = int(len(shuffled) * self.config.train_ratio)
        return shuffled[:split_idx], shuffled[split_idx:]

    def _save_jsonl(self, records: list[dict], filepath: Path) -> None:
        """Write *records* to *filepath* in instruction-tuning JSONL format."""
        filepath = Path(filepath)
        with filepath.open("w", encoding="utf-8") as f:
            for rec in records:
                line = {
                    "instruction": self.INSTRUCTION_TEXT,
                    "input": rec["question"],
                    "output": rec["answer"],
                    "source": rec["source"],
                    "category": rec.get("category", ""),
                }
                f.write(json.dumps(line, ensure_ascii=False) + "\n")
        logger.info("์ €์žฅ ์™„๋ฃŒ: %s (%d ๋ ˆ์ฝ”๋“œ)", filepath, len(records))