indiejoseph committed
Commit 296ac4c (verified) · Parent: 54435c2

Create train.py

Files changed (1):
  train.py (new file, +162 lines)

train.py:

import os
import logging
import warnings

import evaluate
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from transformers import Trainer, TrainingArguments, Wav2Vec2CTCTokenizer

from models.ctc_model import CTCTransformerModel, CTCTransformerConfig
from data import DataCollatorCTCWithPadding, SpeechTokenPhonemeDataset

os.environ["WANDB_PROJECT"] = "speech-phoneme-ctc"

warnings.filterwarnings("ignore")

# Configure the root logger so the logger.info calls below are actually emitted
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

df = pd.read_csv("dataset.csv")
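
# dataset.csv is assumed to pair discrete speech-token sequences with their
# phoneme transcriptions; SpeechTokenPhonemeDataset below consumes this frame.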

# ===== DATASET =====
vocab_path = "vocab/vocab.json"
tokenizer = Wav2Vec2CTCTokenizer(
    vocab_path,
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
)
vocab = tokenizer.get_vocab()
vocab_inv = {v: k for k, v in vocab.items()}
num_speech_tokens = 6561  # size of the discrete speech-token inventory
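
# For reference, Wav2Vec2CTCTokenizer loads vocab.json as a flat token-to-id
# mapping; the entries shown here are illustrative, not the repo's actual
# phoneme inventory:
#   {"[PAD]": 0, "[UNK]": 1, "|": 2, "a": 3, "b": 4, ...}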

# ===== MODEL SETUP =====
config = CTCTransformerConfig(
    vocab_size=num_speech_tokens,  # input vocabulary: discrete speech tokens
    num_labels=len(tokenizer),  # output vocabulary: phoneme labels
    hidden_size=768,
    intermediate_size=3072,
    num_attention_heads=12,
    num_hidden_layers=12,
    max_position_embeddings=1024,
    label2id=vocab,
    id2label=vocab_inv,
    pad_token_id=tokenizer.pad_token_id,  # output padding token
    src_pad_token_id=num_speech_tokens,  # input padding token (one past the last speech token)
)
model = CTCTransformerModel(config)

dataset = SpeechTokenPhonemeDataset(df, tokenizer=tokenizer)
train_valid_dataset = dataset.train_test_split(test_size=0.05, random_state=42)
train_dataset = train_valid_dataset["train"]
eval_dataset = train_valid_dataset["test"]

collator = DataCollatorCTCWithPadding(
    pad_token_id=num_speech_tokens, label_pad_token_id=tokenizer.pad_token_id
)
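
# DataCollatorCTCWithPadding comes from the local data module (not shown in
# this commit). Judging by its arguments, it right-pads input token ids with
# pad_token_id (6561) and label sequences with label_pad_token_id so that
# variable-length examples can be batched together.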

# ===== METRICS =====
cer_metric = evaluate.load("cer")


def compute_metrics(pred):
    label_ids = pred.label_ids
    logits = pred.predictions
    # log-softmax is monotonic, so this argmax equals greedy decoding on raw logits
    log_probs = F.log_softmax(torch.tensor(logits), dim=-1)
    pred_ids = torch.argmax(log_probs, dim=-1).numpy()

    # Replace -100 (positions ignored by the loss, if any) with the pad token ID
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # Decode greedy predictions (with CTC collapse) and verbatim references
    pred_str = tokenizer.batch_decode(pred_ids)
    label_str = tokenizer.batch_decode(label_ids, group_tokens=False)

    # Character error rate between predictions and references
    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer}
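
# Wav2Vec2CTCTokenizer.batch_decode applies the CTC collapse by default
# (group_tokens=True merges repeated tokens before dropping the pad/blank),
# which is why references are decoded with group_tokens=False above.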

# Check vocabulary compatibility and print detailed diagnostic info
print(f"Model input vocab size (speech tokens): {model.config.vocab_size}")
print(f"Model label count: {model.config.num_labels}")
print(f"Tokenizer vocab size: {len(tokenizer)}")
print(
    f"Vocabulary: {list(vocab.keys())[:10]}... (showing first 10 tokens)"
)
print("Training dataset size:", len(train_dataset))
print("Evaluation dataset size:", len(eval_dataset))

# The CTC head needs one logit per tokenizer entry, so compare num_labels
# (not vocab_size, which counts input speech tokens) against the tokenizer
if model.config.num_labels != len(tokenizer):
    print("WARNING: Label count mismatch between model and tokenizer!")

training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=64,
    per_device_eval_batch_size=16,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=10,
    num_train_epochs=10,
    learning_rate=1e-4,
    weight_decay=0.005,
    warmup_ratio=0.1,  # warm up over the first 10% of total training steps
    logging_steps=100,
    logging_dir="./logs",
    gradient_accumulation_steps=1,
    bf16=True,
    report_to="wandb",
    remove_unused_columns=False,  # the custom dataset's columns are consumed by the collator
    dataloader_num_workers=4,
    include_inputs_for_metrics=True,
)
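
# Note: recent transformers releases deprecate include_inputs_for_metrics in
# favor of include_for_metrics=["inputs"]; the older flag still works on the
# versions this script appears to target.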

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collator,
    compute_metrics=compute_metrics,
)
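
# The Trainer expects CTCTransformerModel to return its loss (presumably a
# CTC loss over the phoneme logits) whenever `labels` are passed, per the
# usual Hugging Face model contract; compute_metrics then receives the
# accumulated logits and padded label ids after each evaluation pass.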

logger.info("***** Running training *****")
logger.info(f"  Num examples = {len(train_dataset)}")
logger.info(f"  Num epochs = {training_args.num_train_epochs}")
logger.info(
    f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
)
logger.info(
    f"  Total train batch size (w. parallel, distributed & accumulation) = {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}"
)
logger.info(
    f"  Gradient accumulation steps = {training_args.gradient_accumulation_steps}"
)
# max_steps is -1 here because the schedule is driven by num_train_epochs
logger.info(f"  Total optimization steps = {training_args.max_steps}")
logger.info(f"  Logging steps = {training_args.logging_steps}")
logger.info(f"  Learning rate = {training_args.learning_rate}")
logger.info(f"  Weight decay = {training_args.weight_decay}")
logger.info(f"  Warmup ratio = {training_args.warmup_ratio}")
logger.info(f"  Save total limit = {training_args.save_total_limit}")

train_res = trainer.train()

trainer.save_model()
trainer.save_state()

metrics = train_res.metrics
metrics["train_samples"] = len(train_dataset)
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)

metrics = trainer.evaluate()
metrics["eval_samples"] = len(eval_dataset)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

# Dump the raw Trainer log history next to the checkpoints, one entry per line
with open("results/train.log", "w") as f:
    for obj in trainer.state.log_history:
        f.write(str(obj))
        f.write("\n")

print("- Training complete.")
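
# Expected repository layout, inferred from the paths above: dataset.csv and
# vocab/vocab.json beside this script, with the local models/ and data/
# packages providing CTCTransformerModel and the dataset/collator classes.
# The script is then run directly, e.g. `python train.py`.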