# SillokBert-Scratch Project, Step 3: Pre-training
# -----------------------------------------------------------------
# Uses the outputs of steps 1 and 2 to run Masked Language Modeling
# (MLM) pre-training.
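# MLM masks a random fraction of input tokens and trains the model to
# recover them from bidirectional context (the standard BERT objective).
# -----------------------------------------------------------------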
import os
from pathlib import Path
from transformers import (
    BertConfig,
    BertForMaskedLM,
    PreTrainedTokenizerFast,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset
from itertools import chain

def pretrain_sillok_bert():
    """SillokBert-Scratch λͺ¨λΈμ˜ μ‚¬μ „ν•™μŠ΅μ„ μˆ˜ν–‰ν•©λ‹ˆλ‹€."""
    # --- 경둜 μ„€μ • ---
    project_dir = Path("/home/work/baro/sillok/sillok_scratch_20250626")
    tokenizer_dir = project_dir / "sillok_tokenizer_bpe_preprocessed"
    tokenizer_file = tokenizer_dir / "tokenizer.json"
    dataset_dir = "/home/work/baro/sillok25060103/preprocessed_corpus/"
    output_dir = project_dir / "sillokbert_scratch_pretraining_output"

    print("--- 3. SillokBert-Scratch Pre-training ---")

    # --- ν† ν¬λ‚˜μ΄μ € λ‘œλ“œ ---
    tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_file))
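    # The MLM collator below needs a mask token, and padded batches need a pad
    # token, so register both in case the raw tokenizer.json does not set them.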
    if tokenizer.pad_token is None:
        tokenizer.pad_token = '[PAD]'
    if tokenizer.mask_token is None:
        tokenizer.mask_token = '[MASK]'

    # --- λͺ¨λΈ μ•„ν‚€ν…μ²˜ λ‘œλ“œ ---
    block_size = 512
    config = BertConfig(
        # len(tokenizer) also counts tokens added after tokenizer creation
        # (e.g. [PAD]/[MASK]), unlike tokenizer.vocab_size, so the embedding
        # matrix is guaranteed to cover every id the tokenizer can emit.
        vocab_size=len(tokenizer), hidden_size=768, num_hidden_layers=12,
        num_attention_heads=12, intermediate_size=3072,
        max_position_embeddings=block_size, pad_token_id=tokenizer.pad_token_id,
    )
    model = BertForMaskedLM(config=config)
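    # These hyperparameters match the BERT-base shape (~110M parameters);
    # weights are randomly initialized since this is training from scratch.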

    # --- Prepare dataset ---
    dataset = load_dataset('text', data_files={
        'train': os.path.join(dataset_dir, 'train.txt'),
        'validation': os.path.join(dataset_dir, 'validation.txt'),
    })
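    # The 'text' loader yields one example per line, each carrying a single
    # 'text' column (removed again after tokenization below).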
    
    def tokenize_function(examples):
        return tokenizer(examples['text'], add_special_tokens=False, return_special_tokens_mask=False)
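    # Special tokens are omitted on purpose: group_texts below concatenates
    # everything into one token stream, so no per-line [CLS]/[SEP] is wanted.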

    tokenized_datasets = dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=['text'])
    
    def group_texts(examples):
        """Concatenates all token streams and re-chunks them into fixed-size blocks."""
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # Drop the remainder so every training example is exactly block_size long.
        total_length = (total_length // block_size) * block_size
        result = {k: [t[i : i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items()}
        return result
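    # e.g. with block_size=4, input_ids [[1, 2, 3], [4, 5, 6, 7, 8]] become
    # [[1, 2, 3, 4], [5, 6, 7, 8]]; a trailing partial block would be dropped.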

    lm_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=4)

    # --- Trainer setup and training ---
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
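    # Per batch, 15% of tokens are selected for prediction; of those, 80% are
    # replaced with [MASK], 10% with a random token, and 10% left unchanged
    # (the standard BERT masking recipe implemented by this collator).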
    
    training_args = TrainingArguments(
        output_dir=str(output_dir), overwrite_output_dir=True, num_train_epochs=10,
        # Effective train batch size: 4 x 4 gradient-accumulation steps = 16
        # sequences per device; gradient checkpointing trades compute for memory.
        per_device_train_batch_size=4, per_device_eval_batch_size=8,
        gradient_accumulation_steps=4, gradient_checkpointing=True,
        learning_rate=5e-5, warmup_steps=1000, weight_decay=0.01,
        logging_dir=str(output_dir / 'logs'), logging_steps=500, save_steps=2000,
        # save_steps stays a multiple of eval_steps so load_best_model_at_end
        # can restore the checkpoint with the lowest eval loss.
        eval_strategy="steps", eval_steps=2000, load_best_model_at_end=True, fp16=True,
    )

    trainer = Trainer(
        model=model, args=training_args, data_collator=data_collator,
        train_dataset=lm_datasets['train'], eval_dataset=lm_datasets['validation'],
    )
    
    print("μ‚¬μ „ν•™μŠ΅μ„ μ‹œμž‘ν•©λ‹ˆλ‹€...")
    trainer.train()

    # --- Save final model ---
    final_model_path = output_dir / "final_model"
    trainer.save_model(final_model_path)
    tokenizer.save_pretrained(final_model_path)
    print(f"\nπŸŽ‰ μ‚¬μ „ν•™μŠ΅ μ™„λ£Œ. μ΅œμ’… λͺ¨λΈ μ €μž₯ 경둜: {final_model_path}")

if __name__ == "__main__":
    pretrain_sillok_bert()
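
# --- Example: quick sanity check after training (a minimal sketch) ---
# The fill-mask pipeline is part of the standard transformers API; the model
# path below stands for the final_model directory saved above, and the input
# sentence is a hypothetical placeholder.
#
#   from transformers import pipeline
#   fill = pipeline("fill-mask", model="<output_dir>/final_model")
#   print(fill("a sentence containing one [MASK] token"))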