BERT / train_squad.py
JangTaeng's picture
Upload 7 files
05c8e16 verified
"""
SQuAD v1.1에 BERT를 파인튜닝하는 스크립트
==========================================
Devlin et al. (2019)의 §4.2 SQuAD 실험을 재현합니다.
논문은 추출형 QA를 passage 토큰들에 대한 두 개의 확률 분포 예측 문제로
정의합니다 - 답변 span의 시작 위치와 끝 위치 각각:
P_i = exp(S · T_i) / Σ_j exp(S · T_j)
끝 토큰도 같은 형태로, 학습되는 끝 벡터 E를 사용합니다. span (i, j)의
점수는 S·T_i + E·T_j 이고, j >= i 조건에서 가장 높은 점수의 span을 반환합니다.
§4.2의 하이퍼파라미터:
에폭 수 : 3
학습률 : 5e-5
배치 크기 : 32
실행 예시
---------
python train_squad.py \\
--model_name_or_path bert-base-uncased \\
--output_dir ./out/squad \\
--learning_rate 5e-5 \\
--num_train_epochs 3 \\
--per_device_train_batch_size 32
"""
from __future__ import annotations
import argparse
from datasets import load_dataset
from transformers import (
AutoModelForQuestionAnswering,
AutoTokenizer,
DefaultDataCollator,
Trainer,
TrainingArguments,
set_seed,
)
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser()
p.add_argument("--model_name_or_path", type=str, default="bert-base-uncased")
p.add_argument("--output_dir", type=str, default="./out/squad")
p.add_argument("--max_length", type=int, default=384,
help="시퀀스 최대 길이. 논문은 SQuAD에서 384를 사용.")
p.add_argument("--doc_stride", type=int, default=128,
help="긴 지문을 겹치는 윈도우로 나눌 때의 stride.")
p.add_argument("--learning_rate", type=float, default=5e-5,
help="논문: 5e-5 (§4.2).")
p.add_argument("--num_train_epochs", type=float, default=3.0,
help="논문: 3 에폭 (§4.2).")
p.add_argument("--per_device_train_batch_size", type=int, default=32,
help="논문: 배치 크기 32 (§4.2).")
p.add_argument("--per_device_eval_batch_size", type=int, default=32)
p.add_argument("--seed", type=int, default=42)
return p.parse_args()
def main() -> None:
args = parse_args()
set_seed(args.seed)
# ------------------------------------------------------------------
# 1. 데이터셋 로드
# ------------------------------------------------------------------
raw = load_dataset("squad")
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
def preprocess_train(examples):
# 논문은 (질문, 지문)을 단일 시퀀스로 패킹합니다. 질문은 segment 임베딩 A,
# 지문은 segment 임베딩 B를 사용합니다.
# `tokenizer(question, context, ...)`가 [SEP]와 segment id를 처리해줍니다.
questions = [q.strip() for q in examples["question"]]
tokenized = tokenizer(
questions,
examples["context"],
max_length=args.max_length,
truncation="only_second", # 질문은 절대 자르지 않음
stride=args.doc_stride, # 긴 지문은 겹치는 윈도우로 분할
return_overflowing_tokens=True,
return_offsets_mapping=True,
padding="max_length",
)
# 하나의 지문이 여러 feature로 나뉘면, 각 feature를 원본 예제로 매핑하여
# 답변 span을 찾을 수 있게 합니다.
sample_mapping = tokenized.pop("overflow_to_sample_mapping")
offset_mapping = tokenized.pop("offset_mapping")
tokenized["start_positions"] = []
tokenized["end_positions"] = []
for i, offsets in enumerate(offset_mapping):
input_ids = tokenized["input_ids"][i]
cls_index = input_ids.index(tokenizer.cls_token_id)
sequence_ids = tokenized.sequence_ids(i)
sample_idx = sample_mapping[i]
answers = examples["answers"][sample_idx]
if len(answers["answer_start"]) == 0:
# §4.3: SQuAD v1.1은 모든 질문에 답이 있습니다.
# 만약 v2.0을 돌린다면 답이 없는 경우 [CLS]를 가리키게 합니다.
tokenized["start_positions"].append(cls_index)
tokenized["end_positions"].append(cls_index)
continue
start_char = answers["answer_start"][0]
end_char = start_char + len(answers["text"][0])
# 지문 토큰 범위를 찾습니다 (sequence_id == 1).
token_start_index = 0
while sequence_ids[token_start_index] != 1:
token_start_index += 1
token_end_index = len(input_ids) - 1
while sequence_ids[token_end_index] != 1:
token_end_index -= 1
# 답이 현재 윈도우 밖에 있으면 [CLS]를 가리킵니다.
if not (
offsets[token_start_index][0] <= start_char
and offsets[token_end_index][1] >= end_char
):
tokenized["start_positions"].append(cls_index)
tokenized["end_positions"].append(cls_index)
else:
while (
token_start_index < len(offsets)
and offsets[token_start_index][0] <= start_char
):
token_start_index += 1
tokenized["start_positions"].append(token_start_index - 1)
while offsets[token_end_index][1] >= end_char:
token_end_index -= 1
tokenized["end_positions"].append(token_end_index + 1)
return tokenized
train_dataset = raw["train"].map(
preprocess_train,
batched=True,
remove_columns=raw["train"].column_names,
desc="train 셋 토큰화 중",
)
# 평가에는 offset과 example_id를 유지해야 모델 출력을 문자 span으로
# 다시 매핑할 수 있습니다 (post-processing은 여기선 생략 -
# HuggingFace의 run_qa.py 참고).
eval_dataset = raw["validation"].map(
preprocess_train,
batched=True,
remove_columns=raw["validation"].column_names,
desc="validation 셋 토큰화 중",
)
# ------------------------------------------------------------------
# 2. 모델
# ------------------------------------------------------------------
model = AutoModelForQuestionAnswering.from_pretrained(args.model_name_or_path)
# ------------------------------------------------------------------
# 3. 학습
# ------------------------------------------------------------------
training_args = TrainingArguments(
output_dir=args.output_dir,
evaluation_strategy="epoch",
save_strategy="epoch",
learning_rate=args.learning_rate,
num_train_epochs=args.num_train_epochs,
per_device_train_batch_size=args.per_device_train_batch_size,
per_device_eval_batch_size=args.per_device_eval_batch_size,
weight_decay=0.01,
warmup_ratio=0.1,
seed=args.seed,
report_to="none",
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
data_collator=DefaultDataCollator(),
)
trainer.train()
trainer.save_model(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
print(f"\n모델이 {args.output_dir}에 저장되었습니다.")
print("참고: 공식 SQuAD EM/F1 메트릭을 얻으려면 post-processing이 필요합니다")
print("(https://github.com/huggingface/transformers/tree/main/examples/pytorch/question-answering 참고).")
if __name__ == "__main__":
main()