# -*- coding: utf-8 -*- import os import glob import re import random import logging import argparse from tqdm import tqdm from lxml import etree import numpy as np # --- 로깅 설정 --- logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.StreamHandler()] ) def extract_text_from_paragraph_xml(paragraph_element): """ XML 요소에서 태그를 제외한 순수 텍스트를 추출합니다. """ try: # XPath를 사용하여 의 자손이 아닌 모든 text 노드를 가져옵니다. text_nodes = paragraph_element.xpath("descendant::text()[not(ancestor::annotation)]") full_text = ''.join(text_nodes) # 불필요한 공백을 정규화합니다. cleaned_text = re.sub(r'\s+', ' ', full_text).strip() return cleaned_text except Exception as e: logging.warning(f"텍스트 추출 중 오류 발생 (extract_text_from_paragraph_xml): {e}") return "" def check_xml_structure(paragraph_element): """ 해당 가 'level5 > text > content' 구조 내에 있는지 확인합니다. """ try: parent = paragraph_element.getparent() grandparent = parent.getparent() if parent is not None else None greatgrandparent = grandparent.getparent() if grandparent is not None else None return ( parent is not None and parent.tag == 'content' and grandparent is not None and grandparent.tag == 'text' and greatgrandparent is not None and greatgrandparent.tag == 'level5' ) except AttributeError: return False def save_text_data_to_file(filepath, data_list, description="Saving file"): """ 텍스트 리스트를 지정된 경로에 파일로 저장합니다. """ try: with open(filepath, 'w', encoding='utf-8') as f: for line in tqdm(data_list, desc=description): f.write(line + '\n') logging.info(f" 총 {len(data_list):,} 줄을 {filepath}에 저장했습니다.") except Exception as e: logging.error(f" 파일 저장 중 오류 발생 {filepath}: {e}") def calculate_text_stats(data_list): """ 데이터 리스트의 통계 정보(개수, 최소/최대/평균/중앙값 길이)를 계산합니다. """ if not data_list: return {"count": 0, "min_len": 0, "max_len": 0, "avg_len": 0.0, "median_len": 0.0} lengths = [len(s) for s in data_list] return { "count": len(lengths), "min_len": np.min(lengths) if lengths else 0, "max_len": np.max(lengths) if lengths else 0, "avg_len": np.mean(lengths) if lengths else 0.0, "median_len": np.median(lengths) if lengths else 0.0, } def prepare_sillok_data(xml_dir, output_dir, min_len, markers_pattern, val_ratio, test_ratio, seed_val): """ 메인 데이터 전처리 함수. XML 디렉토리에서 데이터를 읽어 정제하고, 학습/검증/테스트 용으로 분할하여 저장합니다. """ logging.info("--- 데이터 전처리 시작 ---") logging.info(f"1. XML 원본 경로: {xml_dir}") logging.info(f"2. 전처리 결과물 저장 경로: {output_dir}") logging.info(f"3. 최소 문단 길이 필터: {min_len}") logging.info(f"4. 제거할 문단 시작 기호 (정규식): '{markers_pattern}'") logging.info(f"5. 분할 비율 (학습/검증/테스트): {1 - val_ratio - test_ratio:.2f}/{val_ratio:.2f}/{test_ratio:.2f}") logging.info(f"6. 분할 시 사용할 랜덤 시드: {seed_val}") os.makedirs(output_dir, exist_ok=True) # 지정된 디렉토리 하위의 모든 .xml 파일 검색 xml_files = glob.glob(os.path.join(xml_dir, '**', '*.xml'), recursive=True) if not xml_files: logging.error(f"오류: {xml_dir} 경로에서 XML 파일을 찾을 수 없습니다.") return logging.info(f"\n총 {len(xml_files)}개의 XML 파일을 발견했습니다.") all_valid_paragraphs = [] logging.info("\nXML 파일을 처리하여 문단을 추출하고 필터링합니다...") for xml_file_path in tqdm(xml_files, desc="XML 파일 처리 중"): try: tree = etree.parse(xml_file_path) root = tree.getroot() paragraphs = root.xpath('//level5//paragraph') for para_element in paragraphs: if not check_xml_structure(para_element): continue text = extract_text_from_paragraph_xml(para_element) if not text or len(text) < min_len: continue processed_text = re.sub(markers_pattern, '', text).strip() if processed_text: all_valid_paragraphs.append(processed_text) except Exception as e: logging.warning(f"\n파일 처리 중 예상치 못한 오류 발생 ({xml_file_path}): {e}") continue logging.info(f"\nXML 처리 완료. 총 {len(all_valid_paragraphs):,}개의 유효한 문단을 추출했습니다.") if not all_valid_paragraphs: logging.error("유효한 문단이 없어 처리를 중단합니다.") return logging.info("\n데이터를 학습, 검증, 테스트 세트로 분할합니다...") random.seed(seed_val) random.shuffle(all_valid_paragraphs) total_count = len(all_valid_paragraphs) test_idx = int(total_count * test_ratio) valid_idx = test_idx + int(total_count * val_ratio) test_data = all_valid_paragraphs[:test_idx] valid_data = all_valid_paragraphs[test_idx:valid_idx] train_data = all_valid_paragraphs[valid_idx:] logging.info("\n--- 데이터 분할 결과 및 통계 (글자 수 기준) ---") datasets_for_stats = {"학습": train_data, "검증": valid_data, "테스트": test_data} for name, data in datasets_for_stats.items(): stats = calculate_text_stats(data) percentage_of_total = (stats['count'] / total_count if total_count > 0 else 0.0) log_msg = (f" - {name} 데이터: {stats['count']:,} 문단 ({percentage_of_total:.1%}) | " f"Min: {stats['min_len']}, Max: {stats['max_len']}, " f"Avg: {stats['avg_len']:.1f}, Median: {stats['median_len']:.1f}") logging.info(log_msg) logging.info("\n분할된 데이터셋을 텍스트 파일로 저장합니다...") train_filepath = os.path.join(output_dir, "train.txt") valid_filepath = os.path.join(output_dir, "validation.txt") test_filepath = os.path.join(output_dir, "test.txt") save_text_data_to_file(train_filepath, train_data, "train.txt 저장 중") save_text_data_to_file(valid_filepath, valid_data, "validation.txt 저장 중") save_text_data_to_file(test_filepath, test_data, "test.txt 저장 중") logging.info("\n--- 데이터 전처리 완료 ---") logging.info(f"결과물이 다음 경로에 저장되었습니다: {output_dir}") if __name__ == '__main__': parser = argparse.ArgumentParser(description="조선왕조실록 XML 데이터를 전처리하여 텍스트 파일로 변환합니다.") parser.add_argument("--xml_dir", type=str, required=True, help="XML 파일들이 있는 원본 디렉토리 경로") parser.add_argument("--output_dir", type=str, required=True, help="전처리된 텍스트 파일을 저장할 디렉토리 경로") parser.add_argument("--min_len", type=int, default=10, help="필터링할 문단의 최소 글자 수") parser.add_argument("--val_ratio", type=float, default=0.05, help="검증 데이터셋의 비율") parser.add_argument("--test_ratio", type=float, default=0.05, help="테스트 데이터셋의 비율") parser.add_argument("--seed", type=int, default=85, help="데이터 분할 시 사용할 랜덤 시드값") args = parser.parse_args() LEADING_MARKERS_PATTERN = r"^[○▲●◎◇◈▷▶▽▼▣■□▪▫☞⇨]+" prepare_sillok_data( xml_dir=args.xml_dir, output_dir=args.output_dir, min_len=args.min_len, markers_pattern=LEADING_MARKERS_PATTERN, val_ratio=args.val_ratio, test_ratio=args.test_ratio, seed_val=args.seed )