"""Train a Transformer encoder with a CTC head to map discrete speech tokens to phoneme sequences."""

import logging
import os
import warnings

import evaluate
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from transformers import Trainer, TrainingArguments, Wav2Vec2CTCTokenizer

from data import DataCollatorCTCWithPadding, SpeechTokenPhonemeDataset
from models.ctc_model import CTCTransformerModel, CTCTransformerConfig

os.environ["WANDB_PROJECT"] = "speech-phoneme-ctc"

warnings.filterwarnings("ignore")

# Without a configured handler, the logger.info calls below would be silently dropped.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

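# Manifest of training examples; rows are consumed by SpeechTokenPhonemeDataset below.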
df = pd.read_csv("dataset.csv")

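# Build the phoneme tokenizer from the CTC vocabulary; "|" acts as the word delimiter.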
vocab_path = "vocab/vocab.json"
tokenizer = Wav2Vec2CTCTokenizer(
    vocab_path,
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
)
vocab = tokenizer.get_vocab()
vocab_inv = {v: k for k, v in vocab.items()}
num_speech_tokens = 6561

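# Transformer encoder with a CTC head: inputs are discrete speech tokens (vocab_size),
# outputs are phoneme labels (num_labels). The id equal to num_speech_tokens is
# reserved for input padding (src_pad_token_id), matching the collator below.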
config = CTCTransformerConfig(
    vocab_size=num_speech_tokens,
    num_labels=len(tokenizer),
    hidden_size=768,
    intermediate_size=3072,
    num_attention_heads=12,
    num_hidden_layers=12,
    max_position_embeddings=1024,
    label2id=vocab,
    id2label=vocab_inv,
    pad_token_id=tokenizer.pad_token_id,
    src_pad_token_id=num_speech_tokens,
)
model = CTCTransformerModel(config)

dataset = SpeechTokenPhonemeDataset(df, tokenizer=tokenizer)
train_valid_dataset = dataset.train_test_split(test_size=0.05, random_state=42)
train_dataset = train_valid_dataset["train"]
eval_dataset = train_valid_dataset["test"]

collator = DataCollatorCTCWithPadding(
    pad_token_id=num_speech_tokens, label_pad_token_id=tokenizer.pad_token_id
)

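# Character error rate (CER) over decoded phoneme strings is the evaluation metric.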
cer_metric = evaluate.load("cer")

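# Greedy-decode the CTC logits and score predictions against the reference phoneme strings.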
def compute_metrics(pred):
    label_ids = pred.label_ids
    logits = pred.predictions
    log_probs = F.log_softmax(torch.tensor(logits), dim=-1)
    pred_ids = np.argmax(log_probs, axis=-1)

    # -100 marks padded label positions; map them back to the pad token so the
    # tokenizer can decode (and drop) them.
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # batch_decode collapses repeated predictions (CTC-style) by default; references
    # are decoded without grouping so genuine repeats in the ground truth survive.
    pred_str = tokenizer.batch_decode(pred_ids)
    label_str = tokenizer.batch_decode(label_ids, group_tokens=False)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer}

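# Quick sanity checks before training.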
print(f"Speech-token vocab size (model input): {model.config.vocab_size}")
print(f"Phoneme label vocab size (CTC output): {len(tokenizer)}")
print(
    f"Vocabulary: {list(tokenizer.get_vocab().keys())[:10]}... (showing first 10 tokens)"
)
print("Training dataset size:", len(train_dataset))
print("Evaluation dataset size:", len(eval_dataset))

if model.config.num_labels != len(tokenizer.get_vocab()):
    print("WARNING: Label vocabulary size mismatch between model head and tokenizer!")

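# Training hyperparameters: evaluation and checkpointing happen once per epoch,
# and metrics are reported to Weights & Biases.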
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=64,
    per_device_eval_batch_size=16,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=10,
    num_train_epochs=10,
    learning_rate=1e-4,
    weight_decay=0.005,
    warmup_ratio=0.1,
    logging_steps=100,
    logging_dir="./logs",
    gradient_accumulation_steps=1,
    bf16=True,
    report_to="wandb",
    remove_unused_columns=False,
    dataloader_num_workers=4,
    include_inputs_for_metrics=True,
)

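# The Hugging Face Trainer drives the training loop, evaluation, and checkpointing.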
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collator,
    compute_metrics=compute_metrics,
)

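# Log the key hyperparameters before training starts.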
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {training_args.num_train_epochs}")
logger.info(
    f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
)
logger.info(
    f" Total train batch size (w. parallel, distributed & accumulation) = {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size}"
)
logger.info(
    f" Gradient Accumulation steps = {training_args.gradient_accumulation_steps}"
)
logger.info(f" Logging steps = {training_args.logging_steps}")
logger.info(f" Learning rate = {training_args.learning_rate}")
logger.info(f" Weight decay = {training_args.weight_decay}")
logger.info(f" Warmup ratio = {training_args.warmup_ratio}")
logger.info(f" Save total limit = {training_args.save_total_limit}")

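# Train, then persist the final model and trainer state.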
train_res = trainer.train()

trainer.save_model()
trainer.save_state()

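# Record train and eval metrics alongside the corresponding sample counts.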
metrics = train_res.metrics
metrics["train_samples"] = len(train_dataset)
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)

metrics = trainer.evaluate()
metrics["eval_samples"] = len(eval_dataset)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

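# Dump the full log history for later inspection.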
with open(os.path.join(training_args.output_dir, "train.log"), "w") as f:
    for obj in trainer.state.log_history:
        f.write(str(obj))
        f.write("\n")
print("- Training complete.")