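"""Fine-tune ModernBERT-large with masked language modeling on subtitle transcripts.

Pipeline (see main()):
  1. Load the pretrained model and tokenizer, and register [BRK] as a special token.
  2. Read every *.en.txt file in data_dir, strip bracketed timestamps, and join
     the segments with " [BRK] ".
  3. Chunk the token stream into 512-token training examples.
  4. Train with the standard MLM objective (15% masking) and save the model and
     tokenizer to output_dir.
"""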
import os
import re
import glob
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
from datasets import Dataset
import torch


def load_and_process_data(data_dir: str) -> str:
    """
    Load all .en.txt files, remove timestamps, and concatenate with [BRK].
    
    Args:
        data_dir: Directory containing the .en.txt files
        
    Returns:
        Concatenated text with [BRK] separators
    """
    pattern = os.path.join(data_dir, "*.en.txt")
    files = glob.glob(pattern)
    
    if not files:
        raise ValueError(f"No .en.txt files found in {data_dir}")
    
    print(f"Found {len(files)} .en.txt files")
    
    all_segments = []
    
    for file_path in sorted(files):
        print(f"Processing {os.path.basename(file_path)}...")
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:  # Skip empty lines
                    continue
                
                # Remove timestamps in brackets like [0.00] or [2.30]
                # Pattern matches [number] or [number.number]
                line = re.sub(r'\[\d+\.?\d*\]', '', line)
                line = line.strip()
                
                if line:  # Only add non-empty lines after timestamp removal
                    all_segments.append(line)
    
    # Concatenate all segments with [BRK]
    concatenated_text = " [BRK] ".join(all_segments)
    
    print(f"Total segments: {len(all_segments)}")
    print(f"Total text length: {len(concatenated_text)} characters")
    
    return concatenated_text


def prepare_dataset(text: str, tokenizer, max_length: int = 512):
    """
    Tokenize the text and create a dataset for training.
    Preserves [BRK] tokens in the training data so the model can learn to generate them.
    Splits by token count only, not by [BRK] boundaries.
    
    Args:
        text: The concatenated text with [BRK] tokens
        tokenizer: The tokenizer to use
        max_length: Maximum sequence length
        
    Returns:
        Dataset ready for training
    """
    # Tokenize the entire text first to split by token count
    # This preserves [BRK] tokens within chunks
    print("Tokenizing full text...")
    full_tokens = tokenizer(text, add_special_tokens=False, return_offsets_mapping=False)
    input_ids = full_tokens['input_ids']
    
    # Split the token stream into fixed-size chunks. Re-tokenizing each chunk
    # below adds CLS and SEP, so reserve two positions; truncation in
    # tokenize_function catches any chunk that re-encodes slightly longer.
    chunk_size = max_length - 2
    examples = []
    
    for i in range(0, len(input_ids), chunk_size):
        chunk_ids = input_ids[i:i + chunk_size]
        # Decode back to text to preserve [BRK] tokens, then re-tokenize with special tokens
        chunk_text = tokenizer.decode(chunk_ids, skip_special_tokens=False)
        examples.append(chunk_text)
    
    print(f"Created {len(examples)} training examples")
    
    # Tokenize all examples with proper special tokens
    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=max_length,
            padding="max_length",
        )
    
    dataset = Dataset.from_dict({"text": examples})
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
    )
    
    return tokenized_dataset


def main():
    # Configuration
    model_name = "answerdotai/ModernBERT-large"
    data_dir = "/home/allen/Codes/metricsubs-chunktranslate/data"
    output_dir = "."
    
    print("=" * 60)
    print("ModernBERT-large Fine-tuning Script")
    print("=" * 60)
    
    # Step 1: Load model and tokenizer
    print("\n[1/4] Loading model and tokenizer from HuggingFace...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)
    
    # Add [BRK] as a special token
    print("Adding [BRK] as a special token...")
    special_tokens_dict = {"additional_special_tokens": ["[BRK]"]}
    tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))
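    # The embedding row added for [BRK] starts from a fresh initialization and is
    # learned during fine-tuning.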
    
    print(f"Model loaded: {model_name}")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"Vocabulary size: {len(tokenizer)}")
    
    # Step 2: Load and process data
    print("\n[2/4] Loading and processing training data...")
    concatenated_text = load_and_process_data(data_dir)
    
    # Step 3: Prepare dataset
    print("\n[3/4] Preparing dataset...")
    train_dataset = prepare_dataset(concatenated_text, tokenizer, max_length=512)
    
    # Step 4: Set up training
    print("\n[4/4] Setting up training...")
    
    # Data collator for MLM
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=True,
        mlm_probability=0.15,
    )
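    # Note: because [BRK] is registered as an additional special token, the
    # collator's special-tokens mask should exclude it from random masking; it
    # still appears in the inputs, so its embedding is trained through context.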
    
    # Training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        num_train_epochs=3,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
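        # Effective batch size per device: 4 x 4 = 16 sequences.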
        learning_rate=2e-5,
        weight_decay=0.01,
        warmup_steps=500,
        logging_steps=100,
        save_steps=1000,
        save_total_limit=3,
        prediction_loss_only=True,
        fp16=torch.cuda.is_available(),  # Use mixed precision if GPU available
        dataloader_pin_memory=True,
    )
    
    # Initialize trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
    )
    
    # Train
    print("\nStarting training...")
    print(f"Training on {'GPU' if torch.cuda.is_available() else 'CPU'}")
    trainer.train()
    
    # Save the final model
    print(f"\nSaving model to {output_dir}...")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    
    print("\n" + "=" * 60)
    print("Fine-tuning complete!")
    print(f"Model saved to: {os.path.abspath(output_dir)}")
    print("=" * 60)


if __name__ == "__main__":
    main()