# SillokBert / scripts / prepare_data.py
# ddokbaro's picture
# Upload prepare_data.py
# 3c309fa verified
# -*- coding: utf-8 -*-
import os
import glob
import re
import random
import logging
import argparse
from tqdm import tqdm
from lxml import etree
import numpy as np
# --- Logging setup ---
# Configure the root logger at import time: INFO and above, formatted as
# "<time> [<LEVEL>] <message>", sent through a single StreamHandler
# (which writes to sys.stderr by default).
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
def extract_text_from_paragraph_xml(paragraph_element):
    """
    Return the whitespace-normalised text of a <paragraph>, minus <annotation> text.

    Collects every descendant text node of ``paragraph_element`` that does not
    have an ``annotation`` element among its ancestors, concatenates them,
    collapses each whitespace run to a single space and strips both ends.

    Args:
        paragraph_element: element providing an lxml-style ``xpath()`` method.

    Returns:
        The cleaned text, or "" when extraction fails (a warning is logged).
    """
    # NOTE(review): ancestor:: also looks above the paragraph itself, so a
    # paragraph nested inside an <annotation> would yield "".
    query = "descendant::text()[not(ancestor::annotation)]"
    try:
        raw = ''.join(paragraph_element.xpath(query))
        normalised = re.sub(r'\s+', ' ', raw)
        return normalised.strip()
    except Exception as e:
        logging.warning(f"ํ…์ŠคํŠธ ์ถ”์ถœ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ (extract_text_from_paragraph_xml): {e}")
        return ""
def check_xml_structure(paragraph_element):
    """
    Tell whether a <paragraph> sits directly inside level5 > text > content.

    The element's parent must be <content>, that element's parent must be
    <text>, and that element's parent must be <level5>.

    Args:
        paragraph_element: element exposing lxml-style ``getparent()`` and ``tag``.

    Returns:
        True only when all three ancestors exist with exactly those tags;
        False otherwise, including when an object lacks ``getparent``/``tag``.
    """
    # Expected tags for parent, grandparent and great-grandparent, in order.
    expected_chain = ('content', 'text', 'level5')
    try:
        node = paragraph_element
        for wanted_tag in expected_chain:
            node = node.getparent()
            if node is None or node.tag != wanted_tag:
                return False
        return True
    except AttributeError:
        return False
def save_text_data_to_file(filepath, data_list, description="Saving file"):
    """
    Write each string of ``data_list`` to ``filepath`` as one line (UTF-8).

    The file is overwritten. A tqdm progress bar labelled ``description`` is
    shown while writing.

    Args:
        filepath: destination path.
        data_list: iterable of strings; each is written followed by "\\n".
        description: label for the progress bar.

    Returns:
        True if every line was written, False if an error occurred. Errors are
        logged (with traceback) rather than raised, so callers that ignore the
        return value behave as before.
    """
    written = 0
    try:
        with open(filepath, 'w', encoding='utf-8') as f:
            for line in tqdm(data_list, desc=description):
                f.write(line + '\n')
                written += 1
        # Count what was actually written, so generators (no len()) also work.
        logging.info(f" ์ด {written:,} ์ค„์„ {filepath}์— ์ €์žฅํ–ˆ์Šต๋‹ˆ๋‹ค.")
    except Exception as e:
        # Best-effort: keep the traceback in the log and report failure to the caller.
        logging.exception(f" ํŒŒ์ผ ์ €์žฅ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ {filepath}: {e}")
        return False
    return True
def calculate_text_stats(data_list):
    """
    Compute character-length statistics for a collection of strings.

    Args:
        data_list: any iterable of strings (list, tuple, generator, ...).

    Returns:
        dict with "count", "min_len", "max_len" as builtin ints and
        "avg_len", "median_len" as builtin floats. Every value is zero
        when the input is empty.
    """
    # Materialise lengths first so one-shot iterables and numpy arrays also work
    # (``not data_list`` is ambiguous for arrays and always False for generators).
    lengths = [len(s) for s in data_list]
    if not lengths:
        return {"count": 0, "min_len": 0, "max_len": 0, "avg_len": 0.0, "median_len": 0.0}
    # Builtin int/float keep the value types identical to the empty-input case.
    return {
        "count": len(lengths),
        "min_len": min(lengths),
        "max_len": max(lengths),
        "avg_len": float(np.mean(lengths)),
        "median_len": float(np.median(lengths)),
    }
def _collect_valid_paragraphs(xml_files, min_len, marker_regex):
    """
    Parse each XML file and return the cleaned text of every qualifying <paragraph>.

    A paragraph qualifies when check_xml_structure() accepts it, its extracted
    text is non-empty and at least ``min_len`` characters long, and something
    remains after ``marker_regex`` matches are removed and the text stripped.

    Files that raise while being parsed/processed are logged and skipped;
    paragraphs already collected from such a file before the error are kept.
    """
    collected = []
    for xml_file_path in tqdm(xml_files, desc="XML ํŒŒ์ผ ์ฒ˜๋ฆฌ ์ค‘"):
        try:
            root = etree.parse(xml_file_path).getroot()
            for para_element in root.xpath('//level5//paragraph'):
                if not check_xml_structure(para_element):
                    continue
                text = extract_text_from_paragraph_xml(para_element)
                # NOTE(review): the length filter runs before marker removal, so a
                # kept paragraph can end up slightly shorter than min_len.
                if not text or len(text) < min_len:
                    continue
                processed_text = marker_regex.sub('', text).strip()
                if processed_text:
                    collected.append(processed_text)
        except Exception as e:
            # Best-effort per file: one bad file must not abort the whole run.
            logging.warning(f"\nํŒŒ์ผ ์ฒ˜๋ฆฌ ์ค‘ ์˜ˆ์ƒ์น˜ ๋ชปํ•œ ์˜ค๋ฅ˜ ๋ฐœ์ƒ ({xml_file_path}): {e}")
    return collected


def _split_paragraphs(paragraphs, val_ratio, test_ratio, seed_val):
    """
    Shuffle ``paragraphs`` in place with a seeded private RNG and slice it.

    Returns (train, valid, test): the first int(n * test_ratio) items form the
    test split, the next int(n * val_ratio) the validation split, the rest train.
    """
    # A dedicated Random(seed) yields the same permutation as random.seed(seed)
    # followed by random.shuffle(), without resetting the shared module-level RNG.
    random.Random(seed_val).shuffle(paragraphs)
    total_count = len(paragraphs)
    test_idx = int(total_count * test_ratio)
    valid_idx = test_idx + int(total_count * val_ratio)
    return paragraphs[valid_idx:], paragraphs[test_idx:valid_idx], paragraphs[:test_idx]


def prepare_sillok_data(xml_dir, output_dir, min_len, markers_pattern, val_ratio, test_ratio, seed_val):
    """
    Main preprocessing entry point.

    Reads every ``*.xml`` file under ``xml_dir`` (recursively), extracts and
    filters paragraph text, shuffles it with ``seed_val``, splits it into
    train/validation/test sets and writes train.txt, validation.txt and
    test.txt (one paragraph per line) into ``output_dir``.

    Args:
        xml_dir: root directory searched recursively for .xml files.
        output_dir: output directory (created if missing).
        min_len: minimum paragraph length in characters (checked before marker removal).
        markers_pattern: regex (string or compiled); its matches are removed from each paragraph.
        val_ratio: fraction of paragraphs for validation, within [0, 1].
        test_ratio: fraction of paragraphs for test, within [0, 1].
        seed_val: seed for the shuffle that precedes the split.

    Returns:
        None. Invalid ratios, an invalid regex, no XML files or no valid
        paragraphs are logged as errors and end the run early.
    """
    logging.info("--- ๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ ์‹œ์ž‘ ---")
    logging.info(f"1. XML ์›๋ณธ ๊ฒฝ๋กœ: {xml_dir}")
    logging.info(f"2. ์ „์ฒ˜๋ฆฌ ๊ฒฐ๊ณผ๋ฌผ ์ €์žฅ ๊ฒฝ๋กœ: {output_dir}")
    logging.info(f"3. ์ตœ์†Œ ๋ฌธ๋‹จ ๊ธธ์ด ํ•„ํ„ฐ: {min_len}")
    logging.info(f"4. ์ œ๊ฑฐํ•  ๋ฌธ๋‹จ ์‹œ์ž‘ ๊ธฐํ˜ธ (์ •๊ทœ์‹): '{markers_pattern}'")
    logging.info(f"5. ๋ถ„ํ•  ๋น„์œจ (ํ•™์Šต/๊ฒ€์ฆ/ํ…Œ์ŠคํŠธ): {1 - val_ratio - test_ratio:.2f}/{val_ratio:.2f}/{test_ratio:.2f}")
    logging.info(f"6. ๋ถ„ํ•  ์‹œ ์‚ฌ์šฉํ•  ๋žœ๋ค ์‹œ๋“œ: {seed_val}")

    # Validate configuration before touching the filesystem. Negative ratios or
    # a sum above 1 would otherwise produce meaningless slice boundaries.
    if not (0.0 <= val_ratio <= 1.0 and 0.0 <= test_ratio <= 1.0) or val_ratio + test_ratio > 1.0:
        logging.error(
            f"Invalid split ratios: val_ratio={val_ratio}, test_ratio={test_ratio} "
            "(each must be within [0, 1] and their sum must not exceed 1)."
        )
        return
    # Compile once instead of per paragraph; a bad pattern would otherwise fail
    # inside every file's try-block and silently yield zero paragraphs.
    try:
        marker_regex = re.compile(markers_pattern)
    except re.error as e:
        logging.error(f"Invalid markers_pattern {markers_pattern!r}: {e}")
        return

    os.makedirs(output_dir, exist_ok=True)
    # Find every .xml file below xml_dir. glob's order depends on the filesystem,
    # so sort it: the seeded shuffle is only reproducible for a fixed input order.
    xml_files = sorted(glob.glob(os.path.join(xml_dir, '**', '*.xml'), recursive=True))
    if not xml_files:
        logging.error(f"์˜ค๋ฅ˜: {xml_dir} ๊ฒฝ๋กœ์—์„œ XML ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
        return
    logging.info(f"\n์ด {len(xml_files)}๊ฐœ์˜ XML ํŒŒ์ผ์„ ๋ฐœ๊ฒฌํ–ˆ์Šต๋‹ˆ๋‹ค.")

    logging.info("\nXML ํŒŒ์ผ์„ ์ฒ˜๋ฆฌํ•˜์—ฌ ๋ฌธ๋‹จ์„ ์ถ”์ถœํ•˜๊ณ  ํ•„ํ„ฐ๋งํ•ฉ๋‹ˆ๋‹ค...")
    all_valid_paragraphs = _collect_valid_paragraphs(xml_files, min_len, marker_regex)
    logging.info(f"\nXML ์ฒ˜๋ฆฌ ์™„๋ฃŒ. ์ด {len(all_valid_paragraphs):,}๊ฐœ์˜ ์œ ํšจํ•œ ๋ฌธ๋‹จ์„ ์ถ”์ถœํ–ˆ์Šต๋‹ˆ๋‹ค.")
    if not all_valid_paragraphs:
        logging.error("์œ ํšจํ•œ ๋ฌธ๋‹จ์ด ์—†์–ด ์ฒ˜๋ฆฌ๋ฅผ ์ค‘๋‹จํ•ฉ๋‹ˆ๋‹ค.")
        return

    logging.info("\n๋ฐ์ดํ„ฐ๋ฅผ ํ•™์Šต, ๊ฒ€์ฆ, ํ…Œ์ŠคํŠธ ์„ธํŠธ๋กœ ๋ถ„ํ• ํ•ฉ๋‹ˆ๋‹ค...")
    train_data, valid_data, test_data = _split_paragraphs(all_valid_paragraphs, val_ratio, test_ratio, seed_val)
    total_count = len(all_valid_paragraphs)

    logging.info("\n--- ๋ฐ์ดํ„ฐ ๋ถ„ํ•  ๊ฒฐ๊ณผ ๋ฐ ํ†ต๊ณ„ (๊ธ€์ž ์ˆ˜ ๊ธฐ์ค€) ---")
    datasets_for_stats = {"ํ•™์Šต": train_data, "๊ฒ€์ฆ": valid_data, "ํ…Œ์ŠคํŠธ": test_data}
    for name, data in datasets_for_stats.items():
        stats = calculate_text_stats(data)
        percentage_of_total = (stats['count'] / total_count if total_count > 0 else 0.0)
        log_msg = (f" - {name} ๋ฐ์ดํ„ฐ: {stats['count']:,} ๋ฌธ๋‹จ ({percentage_of_total:.1%}) | "
                   f"Min: {stats['min_len']}, Max: {stats['max_len']}, "
                   f"Avg: {stats['avg_len']:.1f}, Median: {stats['median_len']:.1f}")
        logging.info(log_msg)

    logging.info("\n๋ถ„ํ• ๋œ ๋ฐ์ดํ„ฐ์…‹์„ ํ…์ŠคํŠธ ํŒŒ์ผ๋กœ ์ €์žฅํ•ฉ๋‹ˆ๋‹ค...")
    train_filepath = os.path.join(output_dir, "train.txt")
    valid_filepath = os.path.join(output_dir, "validation.txt")
    test_filepath = os.path.join(output_dir, "test.txt")
    save_text_data_to_file(train_filepath, train_data, "train.txt ์ €์žฅ ์ค‘")
    save_text_data_to_file(valid_filepath, valid_data, "validation.txt ์ €์žฅ ์ค‘")
    save_text_data_to_file(test_filepath, test_data, "test.txt ์ €์žฅ ์ค‘")
    logging.info("\n--- ๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ ์™„๋ฃŒ ---")
    logging.info(f"๊ฒฐ๊ณผ๋ฌผ์ด ๋‹ค์Œ ๊ฒฝ๋กœ์— ์ €์žฅ๋˜์—ˆ์Šต๋‹ˆ๋‹ค: {output_dir}")
if __name__ == '__main__':
    # Command-line entry point: parse options and run the full preprocessing pipeline.
    parser = argparse.ArgumentParser(description="์กฐ์„ ์™•์กฐ์‹ค๋ก XML ๋ฐ์ดํ„ฐ๋ฅผ ์ „์ฒ˜๋ฆฌํ•˜์—ฌ ํ…์ŠคํŠธ ํŒŒ์ผ๋กœ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค.")
    parser.add_argument("--xml_dir", type=str, required=True, help="XML ํŒŒ์ผ๋“ค์ด ์žˆ๋Š” ์›๋ณธ ๋””๋ ‰ํ† ๋ฆฌ ๊ฒฝ๋กœ")
    parser.add_argument("--output_dir", type=str, required=True, help="์ „์ฒ˜๋ฆฌ๋œ ํ…์ŠคํŠธ ํŒŒ์ผ์„ ์ €์žฅํ•  ๋””๋ ‰ํ† ๋ฆฌ ๊ฒฝ๋กœ")
    parser.add_argument("--min_len", type=int, default=10, help="ํ•„ํ„ฐ๋งํ•  ๋ฌธ๋‹จ์˜ ์ตœ์†Œ ๊ธ€์ž ์ˆ˜")
    parser.add_argument("--val_ratio", type=float, default=0.05, help="๊ฒ€์ฆ ๋ฐ์ดํ„ฐ์…‹์˜ ๋น„์œจ")
    parser.add_argument("--test_ratio", type=float, default=0.05, help="ํ…Œ์ŠคํŠธ ๋ฐ์ดํ„ฐ์…‹์˜ ๋น„์œจ")
    parser.add_argument("--seed", type=int, default=85, help="๋ฐ์ดํ„ฐ ๋ถ„ํ•  ์‹œ ์‚ฌ์šฉํ•  ๋žœ๋ค ์‹œ๋“œ๊ฐ’")
    args = parser.parse_args()
    # One or more leading marker symbols, removed from the start of each paragraph
    # by prepare_sillok_data. The symbols are written as \u escapes because the
    # previous literal had been mis-decoded (e.g. "○" appeared as "โ—‹") and
    # could no longer match them:
    #   ○ ▲ ● ◎ ◇ ◈ ▷ ▶ ▽ ▼ ▣ ■ □ ▪ ▫ ☞ ⇨
    LEADING_MARKERS_PATTERN = (
        "^["
        "\u25cb\u25b2\u25cf\u25ce\u25c7\u25c8\u25b7\u25b6\u25bd\u25bc"
        "\u25a3\u25a0\u25a1\u25aa\u25ab\u261e\u21e8"
        "]+"
    )
    prepare_sillok_data(
        xml_dir=args.xml_dir,
        output_dir=args.output_dir,
        min_len=args.min_len,
        markers_pattern=LEADING_MARKERS_PATTERN,
        val_ratio=args.val_ratio,
        test_ratio=args.test_ratio,
        seed_val=args.seed
    )