# SillokBert-Scratch / scripts/3_pretrain.py
# Uploaded by ddokbaro ("Upload 15 files", commit 170de4d)
# SillokBert-Scratch 프로젝트 3단계: 사전학습 (Pre-training)
# -----------------------------------------------------------------
# 1, 2단계 결과물을 사용하여 Masked Language Modeling(MLM)으로
# 사전학습을 진행합니다.
# (Step 3 of the SillokBert-Scratch project: pre-train the model with
#  masked language modeling using the outputs of steps 1 and 2.)
# -----------------------------------------------------------------
import os
from pathlib import Path
from transformers import (
BertConfig,
BertForMaskedLM,
PreTrainedTokenizerFast,
DataCollatorForLanguageModeling,
Trainer,
TrainingArguments,
)
from datasets import load_dataset
from itertools import chain
def pretrain_sillok_bert(
    *,
    project_dir="/home/work/baro/sillok/sillok_scratch_20250626",
    dataset_dir="/home/work/baro/sillok25060103/preprocessed_corpus/",
    block_size: int = 512,
) -> Path:
    """Pre-train SillokBert-Scratch from scratch with masked language modeling.

    Loads the step-2 BPE tokenizer, builds a randomly initialised BERT-base
    sized model, packs the step-1 text corpus into fixed-length blocks,
    trains with the Hugging Face ``Trainer`` and saves the final model and
    tokenizer.

    All arguments are keyword-only and default to the previously hard-coded
    values, so calling ``pretrain_sillok_bert()`` behaves as before.

    Args:
        project_dir: Project root. The tokenizer is read from
            ``<project_dir>/sillok_tokenizer_bpe_preprocessed/tokenizer.json``
            and all training output is written under
            ``<project_dir>/sillokbert_scratch_pretraining_output``.
        dataset_dir: Directory containing ``train.txt`` and ``validation.txt``.
        block_size: Token length of each packed training sequence; also used
            as the model's ``max_position_embeddings``.

    Returns:
        The directory the final model and tokenizer were saved to.

    Raises:
        ValueError: If ``block_size`` is not positive.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    # --- Paths ---
    project_dir = Path(project_dir)
    dataset_dir = Path(dataset_dir)
    tokenizer_file = project_dir / "sillok_tokenizer_bpe_preprocessed" / "tokenizer.json"
    output_dir = project_dir / "sillokbert_scratch_pretraining_output"
    print("--- 3. SillokBert-Scratch Pre-training ---")

    # --- Tokenizer ---
    tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_file))
    # Assigning tokenizer.mask_token = "[MASK]" only sets the role string; it
    # does not add the token to the vocabulary, so its id could resolve to
    # [UNK] or None. add_special_tokens() registers the role and appends the
    # token only if it is not already in the vocabulary.
    special_tokens = {}
    if tokenizer.pad_token is None:
        special_tokens["pad_token"] = "[PAD]"
    if tokenizer.mask_token is None:
        special_tokens["mask_token"] = "[MASK]"
    if special_tokens:
        tokenizer.add_special_tokens(special_tokens)

    # --- Model architecture (randomly initialised) ---
    config = BertConfig(
        # len(tokenizer) counts added tokens; tokenizer.vocab_size does not,
        # which would leave added ids outside the embedding matrix.
        vocab_size=len(tokenizer),
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        max_position_embeddings=block_size,
        pad_token_id=tokenizer.pad_token_id,
    )
    model = BertForMaskedLM(config=config)

    # --- Dataset ---
    dataset = load_dataset("text", data_files={
        "train": str(dataset_dir / "train.txt"),
        "validation": str(dataset_dir / "validation.txt"),
    })

    def tokenize_function(examples):
        # No [CLS]/[SEP] per line: lines are concatenated and re-chunked below.
        return tokenizer(examples["text"], add_special_tokens=False, return_special_tokens_mask=False)

    tokenized_datasets = dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=["text"])

    def group_texts(examples):
        # Concatenate every column of the batch, then cut it into block_size
        # chunks; the trailing remainder shorter than block_size is dropped.
        concatenated = {k: list(chain.from_iterable(v)) for k, v in examples.items()}
        total_length = len(next(iter(concatenated.values())))
        total_length = (total_length // block_size) * block_size
        return {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated.items()
        }

    lm_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=4)

    # --- Trainer setup and training ---
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
    training_args = TrainingArguments(
        output_dir=str(output_dir),
        overwrite_output_dir=True,
        num_train_epochs=10,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        learning_rate=5e-5,
        warmup_steps=1000,
        weight_decay=0.01,
        logging_dir=str(output_dir / "logs"),
        logging_steps=500,
        # load_best_model_at_end requires save_steps to be a multiple of
        # eval_steps; both are 2000 here.
        save_steps=2000,
        eval_strategy="steps",
        eval_steps=2000,
        load_best_model_at_end=True,
        fp16=True,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=lm_datasets["train"],
        eval_dataset=lm_datasets["validation"],
    )
    print("사전학습을 시작합니다...")
    trainer.train()

    # --- Save the final model ---
    final_model_path = output_dir / "final_model"
    trainer.save_model(str(final_model_path))
    # A tokenizer built from a bare tokenizer.json has an effectively unlimited
    # model_max_length; record the trained context length so downstream
    # truncation=True stops at block_size. Set only after training so the
    # pre-training tokenization above is unaffected.
    tokenizer.model_max_length = block_size
    tokenizer.save_pretrained(str(final_model_path))
    print(f"\n🎉 사전학습 완료. 최종 모델 저장 경로: {final_model_path}")
    return final_model_path
if __name__ == "__main__":
    # Script entry point: run the full pre-training pipeline.
    pretrain_sillok_bert()