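"""
Fine-tune ModernBERT-large as a masked language model on subtitle text.

Loads all *.en.txt transcripts from a data directory, strips bracketed
timestamps, joins the segments with a new [BRK] special token, and trains
with the Hugging Face Trainer in the standard MLM setup.
"""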
import os
import re
import glob
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
from datasets import Dataset
import torch

def load_and_process_data(data_dir: str) -> str:
    """
    Load all .en.txt files, strip timestamps, and concatenate segments with [BRK].

    Args:
        data_dir: Directory containing the .en.txt files

    Returns:
        Concatenated text with [BRK] separators
    """
    pattern = os.path.join(data_dir, "*.en.txt")
    files = glob.glob(pattern)
    if not files:
        raise ValueError(f"No .en.txt files found in {data_dir}")
    print(f"Found {len(files)} .en.txt files")

    all_segments = []
    for file_path in sorted(files):
        print(f"Processing {os.path.basename(file_path)}...")
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:  # Skip empty lines
                    continue
                # Remove bracketed timestamps such as [0.00], [2.30], or [0:01:23].
                # The pattern covers both the decimal form [number.number] and the
                # colon-separated form [number:number:number].
                line = re.sub(r'\[\d+(?::\d+)*(?:\.\d+)?\]', '', line)
                line = line.strip()
                if line:  # Only keep non-empty lines after timestamp removal
                    all_segments.append(line)

    # Concatenate all segments with [BRK]
    concatenated_text = " [BRK] ".join(all_segments)
    print(f"Total segments: {len(all_segments)}")
    print(f"Total text length: {len(concatenated_text)} characters")
    return concatenated_text
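
# Added sketch: a minimal, optional sanity check for the timestamp-stripping
# regex above. The helper is illustrative only and is not called by the
# training flow; run it manually to confirm the pattern's behavior.
def _check_timestamp_stripping():
    samples = {
        "[0.00] Hello world": "Hello world",
        "[2.30] Another line": "Another line",
        "[0:01:23] Colon-style stamp": "Colon-style stamp",
        "No timestamp here": "No timestamp here",
    }
    for raw, expected in samples.items():
        cleaned = re.sub(r'\[\d+(?::\d+)*(?:\.\d+)?\]', '', raw).strip()
        assert cleaned == expected, f"{raw!r} -> {cleaned!r}, expected {expected!r}"
    print("Timestamp-stripping checks passed")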

def prepare_dataset(text: str, tokenizer, max_length: int = 512):
    """
    Tokenize the text and create a dataset for training.

    Preserves [BRK] tokens in the training data so the model can learn to
    generate them. Splits by token count only, not at [BRK] boundaries.

    Args:
        text: The concatenated text with [BRK] tokens
        tokenizer: The tokenizer to use
        max_length: Maximum sequence length

    Returns:
        Dataset ready for training
    """
    # Tokenize the entire text first so chunks can be split by token count.
    # This keeps [BRK] tokens intact within chunks.
    print("Tokenizing full text...")
    full_tokens = tokenizer(text, add_special_tokens=False, return_offsets_mapping=False)
    input_ids = full_tokens['input_ids']

    # Split into chunks of at most max_length - 2 tokens, reserving space for
    # the CLS and SEP tokens added when each chunk is re-tokenized below.
    chunk_size = max_length - 2
    examples = []
    for i in range(0, len(input_ids), chunk_size):
        chunk_ids = input_ids[i:i + chunk_size]
        # Decode back to text (keeping [BRK] intact), then re-tokenize with
        # special tokens in tokenize_function below.
        chunk_text = tokenizer.decode(chunk_ids, skip_special_tokens=False)
        examples.append(chunk_text)
    print(f"Created {len(examples)} training examples")

    # Tokenize all examples with proper special tokens
    def tokenize_function(batch):
        return tokenizer(
            batch["text"],
            truncation=True,
            max_length=max_length,
            padding="max_length",
        )

    dataset = Dataset.from_dict({"text": examples})
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
    )
    return tokenized_dataset
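
# Added sketch: a quick way to confirm that [BRK] survives the
# chunk -> decode -> re-tokenize round trip used in prepare_dataset.
# Assumes [BRK] has already been registered as a special token on the
# tokenizer (as done in main below); not called by the training flow.
def _check_brk_round_trip(tokenizer):
    brk_id = tokenizer.convert_tokens_to_ids("[BRK]")
    sample = "first segment [BRK] second segment"
    ids = tokenizer(sample, add_special_tokens=False)['input_ids']
    decoded = tokenizer.decode(ids, skip_special_tokens=False)
    re_ids = tokenizer(decoded, add_special_tokens=False)['input_ids']
    assert brk_id in ids and brk_id in re_ids, "[BRK] was lost in the round trip"
    print(f"[BRK] id {brk_id} preserved through decode/re-tokenize")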

def main():
    # Configuration
    model_name = "answerdotai/ModernBERT-large"
    data_dir = "/home/allen/Codes/metricsubs-chunktranslate/data"
    output_dir = "."

    print("=" * 60)
    print("ModernBERT-large Fine-tuning Script")
    print("=" * 60)
    # Step 1: Load model and tokenizer
    print("\n[1/4] Loading model and tokenizer from HuggingFace...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)

    # Add [BRK] as a special token and resize the embedding matrix to cover
    # the enlarged vocabulary.
    print("Adding [BRK] as a special token...")
    special_tokens_dict = {"additional_special_tokens": ["[BRK]"]}
    tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    print(f"Model loaded: {model_name}")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"Vocabulary size: {len(tokenizer)}")

    # Step 2: Load and process data
    print("\n[2/4] Loading and processing training data...")
    concatenated_text = load_and_process_data(data_dir)

    # Step 3: Prepare dataset
    print("\n[3/4] Preparing dataset...")
    train_dataset = prepare_dataset(concatenated_text, tokenizer, max_length=512)

    # Step 4: Set up training
    print("\n[4/4] Setting up training...")
    # Data collator for MLM: masks 15% of tokens in each batch and builds the
    # labels so the model learns to reconstruct the originals.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=True,
        mlm_probability=0.15,
    )

    # Training arguments: effective batch size is 16
    # (per_device_train_batch_size=4 x gradient_accumulation_steps=4)
    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        num_train_epochs=3,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=2e-5,
        weight_decay=0.01,
        warmup_steps=500,
        logging_steps=100,
        save_steps=1000,
        save_total_limit=3,
        prediction_loss_only=True,
        fp16=torch.cuda.is_available(),  # Use mixed precision if a GPU is available
        dataloader_pin_memory=True,
    )

    # Initialize trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
    )

    # Train
    print("\nStarting training...")
    print(f"Training on {'GPU' if torch.cuda.is_available() else 'CPU'}")
    trainer.train()

    # Save the final model and tokenizer (the tokenizer now includes [BRK])
    print(f"\nSaving model to {output_dir}...")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    print("\n" + "=" * 60)
    print("Fine-tuning complete!")
    print(f"Model saved to: {os.path.abspath(output_dir)}")
    print("=" * 60)


if __name__ == "__main__":
    main()
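
# Added usage sketch (assumption: run from the directory where the model was
# saved). Loading the fine-tuned checkpoint back for masked-token inference:
#
#   from transformers import AutoTokenizer, AutoModelForMaskedLM
#   tok = AutoTokenizer.from_pretrained(".")
#   mdl = AutoModelForMaskedLM.from_pretrained(".")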