# SillokBert-Scratch project, stage 3: Pre-training
# -----------------------------------------------------------------
# Uses the artifacts from stages 1 and 2 to pre-train the model
# with Masked Language Modeling (MLM).
# -----------------------------------------------------------------
import os
from itertools import chain
from pathlib import Path

from datasets import load_dataset
from transformers import (
    BertConfig,
    BertForMaskedLM,
    DataCollatorForLanguageModeling,
    PreTrainedTokenizerFast,
    Trainer,
    TrainingArguments,
)


def pretrain_sillok_bert():
    """Run pre-training for the SillokBert-Scratch model."""
    # --- Path settings ---
    project_dir = Path("/home/work/baro/sillok/sillok_scratch_20250626")
    tokenizer_dir = project_dir / "sillok_tokenizer_bpe_preprocessed"
    tokenizer_file = tokenizer_dir / "tokenizer.json"
    dataset_dir = "/home/work/baro/sillok25060103/preprocessed_corpus/"
    output_dir = project_dir / "sillokbert_scratch_pretraining_output"

    print("--- 3. SillokBert-Scratch Pre-training ---")

    # --- Load tokenizer ---
    tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_file))
    # [PAD] is needed for batching, [MASK] for the MLM objective.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = '[PAD]'
    if tokenizer.mask_token is None:
        tokenizer.mask_token = '[MASK]'

    # --- Define model architecture ---
    block_size = 512
    config = BertConfig(
        # len(tokenizer) includes tokens added above ([PAD], [MASK]);
        # tokenizer.vocab_size would not, and their ids could then
        # index past the embedding table.
        vocab_size=len(tokenizer),
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        max_position_embeddings=block_size,
        pad_token_id=tokenizer.pad_token_id,
    )
    model = BertForMaskedLM(config=config)

    # --- Prepare dataset ---
    dataset = load_dataset('text', data_files={
        'train': os.path.join(dataset_dir, 'train.txt'),
        'validation': os.path.join(dataset_dir, 'validation.txt'),
    })

    def tokenize_function(examples):
        # No special tokens here; sequences are concatenated and
        # re-chunked in group_texts below.
        return tokenizer(
            examples['text'],
            add_special_tokens=False,
            return_special_tokens_mask=False,
        )

    tokenized_datasets = dataset.map(
        tokenize_function, batched=True, num_proc=4, remove_columns=['text']
    )

    def group_texts(examples):
        # Concatenate every field across the batch, then slice the
        # token stream into fixed-size blocks; any tail shorter than
        # block_size is dropped.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        total_length = (total_length // block_size) * block_size
        return {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }

    lm_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=4)

    # --- Trainer setup and training ---
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=True, mlm_probability=0.15
    )
    training_args = TrainingArguments(
        output_dir=str(output_dir),
        overwrite_output_dir=True,
        num_train_epochs=10,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=4,  # effective batch size 16 per device
        gradient_checkpointing=True,
        learning_rate=5e-5,
        warmup_steps=1000,
        weight_decay=0.01,
        logging_dir=str(output_dir / 'logs'),
        logging_steps=500,
        save_steps=2000,
        eval_strategy="steps",  # `evaluation_strategy` in transformers < 4.41
        eval_steps=2000,
        load_best_model_at_end=True,
        fp16=True,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=lm_datasets['train'],
        eval_dataset=lm_datasets['validation'],
    )

    print("Starting pre-training...")
    trainer.train()

    # --- Save final model ---
    final_model_path = output_dir / "final_model"
    trainer.save_model(str(final_model_path))
    tokenizer.save_pretrained(str(final_model_path))
    print(f"\n🎉 Pre-training complete. Final model saved to: {final_model_path}")


if __name__ == "__main__":
    pretrain_sillok_bert()
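

# -----------------------------------------------------------------
# Optional post-training check (not part of the original pipeline):
# a minimal sketch that reloads the saved checkpoint and runs one
# masked prediction, confirming that the model and tokenizer saved
# above round-trip through a fill-mask pipeline. The function name
# and the sample sentence are hypothetical placeholders; substitute
# a line from the preprocessed Sillok corpus.
# -----------------------------------------------------------------
def smoke_test(model_path):
    """Print top fill-mask predictions from a saved checkpoint."""
    from transformers import pipeline

    # Loads the model and tokenizer from the final_model directory.
    fill_mask = pipeline("fill-mask", model=model_path, tokenizer=model_path)
    sample = f"上曰 {fill_mask.tokenizer.mask_token} 之"  # placeholder text
    for pred in fill_mask(sample, top_k=5):
        print(f"{pred['token_str']!r}: {pred['score']:.4f}")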