# SillokBert-Scratch Project, Stage 3: Pre-training
# -----------------------------------------------------------------
# Uses the outputs of Stages 1 and 2 to run pre-training with
# Masked Language Modeling (MLM).
# -----------------------------------------------------------------
import os
from pathlib import Path
from transformers import (
    BertConfig,
    BertForMaskedLM,
    PreTrainedTokenizerFast,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset
from itertools import chain
def pretrain_sillok_bert():
"""SillokBert-Scratch λͺ¨λΈμ μ¬μ νμ΅μ μνν©λλ€."""
# --- κ²½λ‘ μ€μ ---
project_dir = Path("/home/work/baro/sillok/sillok_scratch_20250626")
tokenizer_dir = project_dir / "sillok_tokenizer_bpe_preprocessed"
tokenizer_file = tokenizer_dir / "tokenizer.json"
dataset_dir = "/home/work/baro/sillok25060103/preprocessed_corpus/"
output_dir = project_dir / "sillokbert_scratch_pretraining_output"
print("--- 3. SillokBert-Scratch Pre-training ---")
    # --- Load tokenizer ---
    tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_file))
    if tokenizer.pad_token is None:
        tokenizer.pad_token = '[PAD]'
    if tokenizer.mask_token is None:
        tokenizer.mask_token = '[MASK]'
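    # Note (assumption): '[PAD]' and '[MASK]' are expected to already exist in the
    # Stage-2 tokenizer.json vocabulary. If they were instead added here as new
    # tokens, the config below should use len(tokenizer) rather than
    # tokenizer.vocab_size so that their ids stay within the embedding matrix.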
    # --- Define model architecture ---
    block_size = 512
    config = BertConfig(
        vocab_size=tokenizer.vocab_size, hidden_size=768, num_hidden_layers=12,
        num_attention_heads=12, intermediate_size=3072,
        max_position_embeddings=block_size, pad_token_id=tokenizer.pad_token_id,
    )
    model = BertForMaskedLM(config=config)
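    # These dimensions (hidden 768, 12 layers, 12 heads, intermediate 3072) match
    # BERT-base. BertForMaskedLM(config) initializes the weights randomly, so this
    # is a true from-scratch pre-training run rather than continued pre-training.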
    # --- Prepare dataset ---
    dataset = load_dataset('text', data_files={
        'train': os.path.join(dataset_dir, 'train.txt'),
        'validation': os.path.join(dataset_dir, 'validation.txt'),
    })
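    # The 'text' loading script yields one example per line of train.txt and
    # validation.txt, so the corpus is assumed to be plain text, one record per line.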
    def tokenize_function(examples):
        return tokenizer(examples['text'], add_special_tokens=False, return_special_tokens_mask=False)

    tokenized_datasets = dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=['text'])

    def group_texts(examples):
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        total_length = (total_length // block_size) * block_size
        result = {k: [t[i : i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items()}
        return result

    lm_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=4)
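    # After grouping, every example is exactly block_size (512) tokens long; the
    # trailing remainder shorter than block_size is dropped by the integer
    # division above.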
    # --- Trainer setup and training ---
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
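    # The collator applies dynamic masking at batch time: 15% of tokens are
    # selected, and of those 80% become [MASK], 10% a random token, and 10% stay
    # unchanged (the collator's default 80/10/10 split).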
    training_args = TrainingArguments(
        output_dir=str(output_dir), overwrite_output_dir=True, num_train_epochs=10,
        per_device_train_batch_size=4, per_device_eval_batch_size=8,
        gradient_accumulation_steps=4, gradient_checkpointing=True,
        learning_rate=5e-5, warmup_steps=1000, weight_decay=0.01,
        logging_dir=str(output_dir / 'logs'), logging_steps=500, save_steps=2000,
        eval_strategy="steps", eval_steps=2000, load_best_model_at_end=True, fp16=True,
    )
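    # Effective train batch size per device = 4 (per_device) x 4 (gradient
    # accumulation) = 16 sequences of 512 tokens; fp16 and gradient checkpointing
    # assume a CUDA GPU.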
    trainer = Trainer(
        model=model, args=training_args, data_collator=data_collator,
        train_dataset=lm_datasets['train'], eval_dataset=lm_datasets['validation'],
    )
    print("Starting pre-training...")
    trainer.train()
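    # If the run is interrupted, it can be resumed from the latest checkpoint in
    # output_dir via trainer.train(resume_from_checkpoint=True).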
    # --- Save final model ---
    final_model_path = output_dir / "final_model"
    trainer.save_model(final_model_path)
    tokenizer.save_pretrained(final_model_path)
    print(f"\nPre-training complete. Final model saved to: {final_model_path}")
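    # Illustrative reload (assumes the saved paths above), e.g. for downstream use:
    #   reloaded_model = BertForMaskedLM.from_pretrained(final_model_path)
    #   reloaded_tokenizer = PreTrainedTokenizerFast.from_pretrained(final_model_path)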
if __name__ == "__main__":
    pretrain_sillok_bert()