|
|
import pandas as pd |
|
|
import os |
|
|
from typing import List |
|
|
|
|
|
|
|
|
# Hard-coded data locations from the original author's machine; adjust them
# for your own environment before running this script.
_DATASET_DIR = "/home/hsichen/part_time/BERT_finetune/dataset_pretrain"

# Source CSV; its 'sentence' column is what gets extracted.
INPUT_CSV_PATH = f"{_DATASET_DIR}/Experiment_sentences_training_filtered_part1.csv"

# Destination plain-text corpus (one sentence per line).
OUTPUT_TXT_PATH = f"{_DATASET_DIR}/domain_corpus.txt"

# Text encoding handed to prepare_dapt_data.
ENCODING = 'utf-8'
|
|
|
|
|
def _normalize_sentence(text: str) -> str:
    """Return *text* collapsed onto a single line with surrounding whitespace trimmed.

    Pieces separated by line-break characters (as recognised by
    ``str.splitlines``, e.g. from quoted multi-line CSV fields) are stripped
    and re-joined with a single space. Empty pieces are dropped. A
    whitespace-only input therefore yields ``""``.
    """
    pieces = (piece.strip() for piece in text.splitlines())
    return " ".join(piece for piece in pieces if piece)


def prepare_dapt_data(input_csv_path: str, output_txt_path: str, encoding: str):
    """Extract the 'sentence' column of a CSV into a plain-text file, one sentence per line.

    Missing (NaN) cells are dropped, and non-string values are converted with
    ``str``. Embedded line breaks are replaced by spaces, and cells that are
    empty after trimming are discarded. Every written line, including the
    last, ends with ``'\\n'``.

    Args:
        input_csv_path: Path of the source CSV file.
        output_txt_path: Path of the plain-text file to write. Missing parent
            directories are created.
        encoding: Text encoding used both to read the CSV and to write the
            output file.

    Returns:
        None. Problems (input file not found, unreadable CSV, no 'sentence'
        column, no usable sentences, write failure) are reported with
        ``print``, and the function returns early instead of raising.
    """
    print(f"--- 1. 读取数据: {input_csv_path} ---")
    try:
        df = pd.read_csv(input_csv_path, encoding=encoding)
    except FileNotFoundError:
        print(f"错误:输入文件未找到在路径: {input_csv_path}")
        return
    except Exception as e:  # script boundary: report any parse/decode error and stop
        print(f"读取 CSV 文件时发生错误: {e}")
        return

    if 'sentence' not in df.columns:
        print("错误:CSV 文件中未找到 'sentence' 列。请检查列名是否正确。")
        return

    print("--- 2. 清洗句子 ---")
    # Clean first and only then test for emptiness, so whitespace-only cells
    # never become blank lines and are not counted as valid sentences.
    sentences: List[str] = [
        cleaned
        for cleaned in map(_normalize_sentence, df['sentence'].dropna().astype(str))
        if cleaned
    ]

    if not sentences:
        print("警告:'sentence' 列中没有有效数据,无法生成语料库。")
        return

    print(f"提取到 {len(sentences)} 条有效句子。")

    print(f"--- 3. 保存至纯文本文件: {output_txt_path} ---")
    try:
        out_dir = os.path.dirname(output_txt_path)
        if out_dir:  # dirname is "" for a bare filename; makedirs("") would raise
            os.makedirs(out_dir, exist_ok=True)
        with open(output_txt_path, 'w', encoding=encoding) as f:
            # Newline-terminate every line so the final sentence is a complete line too.
            f.writelines(f"{sentence}\n" for sentence in sentences)
        print("数据成功保存!")
    except Exception as e:  # script boundary: report OS/encoding errors and stop
        print(f"写入文件时发生错误: {e}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the extraction using the module-level configuration.
    prepare_dapt_data(
        input_csv_path=INPUT_CSV_PATH,
        output_txt_path=OUTPUT_TXT_PATH,
        encoding=ENCODING,
    )