# speech2phone-ctc / train.py
import os
import torch
from transformers import Trainer, TrainingArguments, Wav2Vec2CTCTokenizer
import torch.nn.functional as F
from models.ctc_model import CTCTransformerModel, CTCTransformerConfig
from data import DataCollatorCTCWithPadding, SpeechTokenPhonemeDataset
import evaluate
import numpy as np
import pandas as pd
import logging
import warnings
# ===== ENVIRONMENT / LOGGING SETUP =====
os.environ["WANDB_PROJECT"] = "speech-phoneme-ctc"
warnings.filterwarnings("ignore")

# Configure the root logger so the logger.info(...) calls later in this
# script are actually emitted; without this, the default WARNING level
# silently drops every INFO-level progress line.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Training manifest, one row per utterance (presumably holding the discrete
# speech-token ids and phoneme targets consumed by SpeechTokenPhonemeDataset
# below — schema defined in data.py, not visible here).
df = pd.read_csv(
    "dataset.csv",
)
# ===== TOKENIZER =====
# CTC tokenizer over the phoneme vocabulary; "|" delimits words.
vocab_path = "vocab/vocab.json"
tokenizer = Wav2Vec2CTCTokenizer(
    vocab_path,
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
)
# Forward (token -> id) and inverse (id -> token) lookup tables,
# used below as label2id / id2label on the model config.
vocab = tokenizer.get_vocab()
vocab_inv = dict(zip(vocab.values(), vocab.keys()))
# Size of the discrete speech-token inventory (the model's input vocabulary).
num_speech_tokens = 6561
# ===== MODEL SETUP =====
# Encoder hyper-parameters (BERT-base sized) plus the two vocabularies:
# speech tokens on the input side, phoneme labels on the output side.
_model_kwargs = dict(
    vocab_size=num_speech_tokens,
    num_labels=len(tokenizer),
    hidden_size=768,
    intermediate_size=3072,
    num_attention_heads=12,
    num_hidden_layers=12,
    max_position_embeddings=1024,
    label2id=vocab,
    id2label=vocab_inv,
    pad_token_id=tokenizer.pad_token_id,  # output padding token
    src_pad_token_id=num_speech_tokens,  # input padding token
)
config = CTCTransformerConfig(**_model_kwargs)
model = CTCTransformerModel(config)
# ===== DATA =====
# Wrap the manifest in a dataset and hold out 5% for evaluation; the fixed
# seed keeps the split reproducible across runs.
dataset = SpeechTokenPhonemeDataset(df, tokenizer=tokenizer)
split = dataset.train_test_split(test_size=0.05, random_state=42)
train_dataset, eval_dataset = split["train"], split["test"]

# Batch collator: pads inputs with the reserved speech-token id and labels
# with the tokenizer's pad id.
collator = DataCollatorCTCWithPadding(
    pad_token_id=num_speech_tokens,
    label_pad_token_id=tokenizer.pad_token_id,
)
# ===== METRICS =====
cer_metric = evaluate.load("cer")


def compute_metrics(pred):
    """Compute the character error rate (CER) for an evaluation pass.

    Args:
        pred: EvalPrediction-like object with ``predictions`` (per-frame
            label logits) and ``label_ids`` (padded with -100 where labels
            should be ignored).

    Returns:
        dict: ``{"cer": float}``.
    """
    label_ids = pred.label_ids
    logits = pred.predictions
    # Greedy decode: argmax over the label axis. (log_)softmax is monotonic,
    # so applying it before argmax — as the original torch round-trip did —
    # cannot change the result and is dropped.
    pred_ids = np.argmax(logits, axis=-1)
    # Replace the ignore index (-100) with the pad token so decoding works.
    label_ids[label_ids == -100] = tokenizer.pad_token_id
    # Decode predictions (CTC token grouping) and references (no grouping).
    pred_str = tokenizer.batch_decode(pred_ids)
    label_str = tokenizer.batch_decode(label_ids, group_tokens=False)
    # Calculate CER between decoded hypotheses and references.
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}
# ===== SANITY CHECKS =====
# Print vocabulary / dataset diagnostics before training starts.
print("Model vocab size:", model.config.vocab_size)
print("Tokenizer vocab size:", len(tokenizer))
_first_tokens = list(tokenizer.get_vocab().keys())[:10]
print(f"Vocabulary: {_first_tokens}... (showing first 10 tokens)")
print(f"Training dataset size: {len(train_dataset)}")
print(f"Evaluation dataset size: {len(eval_dataset)}")
# The CTC head size must match the tokenizer; warn loudly if it does not.
if model.config.vocab_size != len(tokenizer.get_vocab()):
    print("WARNING: Vocabulary size mismatch between model and tokenizer!")
# ===== TRAINING CONFIGURATION =====
training_args = TrainingArguments(
    # Output / checkpointing
    output_dir="./results",
    save_strategy="epoch",
    save_total_limit=10,
    # Batching
    per_device_train_batch_size=64,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=1,
    dataloader_num_workers=4,
    # Optimization schedule
    num_train_epochs=10,
    learning_rate=1e-4,
    weight_decay=0.005,
    warmup_ratio=0.1,
    bf16=True,  # bfloat16 mixed precision
    # Evaluation / logging
    eval_strategy="epoch",
    logging_steps=100,
    logging_dir="./logs",
    report_to="wandb",
    include_inputs_for_metrics=True,
    # Keep raw dataset columns so the custom collator sees every field.
    remove_unused_columns=False,
)
# Wire model, data, collator and metric function into the HF Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collator,
    compute_metrics=compute_metrics,  # reports CER after each eval epoch
)
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {training_args.num_train_epochs}")
logger.info(
f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
)
logger.info(
f" Total train batch size (w. parallel, distributed & accumulation) = {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}"
)
logger.info(
f" Gradient Accumulation steps = {training_args.gradient_accumulation_steps}"
)
logger.info(f" Total optimization steps = {training_args.max_steps}")
logger.info(f" Logging steps = {training_args.logging_steps}")
logger.info(f" Learning rate = {training_args.learning_rate}")
logger.info(f" Weight decay = {training_args.weight_decay}")
logger.info(f" Warmup steps = {training_args.warmup_steps}")
logger.info(f" Save total limit = {training_args.save_total_limit}")
# ===== TRAIN / EVALUATE / PERSIST =====
train_result = trainer.train()
trainer.save_model()
trainer.save_state()

# Record training metrics alongside the number of examples seen.
train_metrics = train_result.metrics
train_metrics["train_samples"] = len(train_dataset)
trainer.log_metrics("train", train_metrics)
trainer.save_metrics("train", train_metrics)

# Final evaluation on the held-out split.
eval_metrics = trainer.evaluate()
eval_metrics["eval_samples"] = len(eval_dataset)
trainer.log_metrics("eval", eval_metrics)
trainer.save_metrics("eval", eval_metrics)

# Dump the full step-by-step log history, one entry per line.
with open("results/train.log", "w") as f:
    for entry in trainer.state.log_history:
        f.write(str(entry) + "\n")
print("- Training complete.")