AIOnTheEdge committed on
Commit
6ff2492
·
verified ·
1 Parent(s): ce10cf8

Create train_acft.py

Browse files
Files changed (1) hide show
  1. train_acft.py +201 -0
train_acft.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import torch
4
+ from tqdm import tqdm
5
+ from torch import nn
6
+ from datasets import load_dataset, Audio
7
+ from transformers import WhisperModel, WhisperProcessor, get_linear_schedule_with_warmup
8
+ from torch.utils.tensorboard import SummaryWriter
9
+ from dotenv import load_dotenv
10
+
11
+ load_dotenv()
12
+ HF_TOKEN = os.environ.get("HF_TOKEN")
13
+
14
class SlicedEmbedding(nn.Module):
    """Expose only the first ``n_ctx`` rows of an existing embedding module.

    The wrapper offers two access paths to the wrapped embedding:
    the ``weight`` property returns the leading ``n_ctx`` rows of its weight
    matrix, and calling the wrapper forwards the indices to the wrapped
    module unchanged. ``num_embeddings`` reports ``n_ctx``.
    """

    def __init__(self, orig_embed, n_ctx):
        super().__init__()
        self.n_ctx = self.num_embeddings = n_ctx
        # Held inside a plain list, so nn.Module does not register it as a
        # submodule (presumably to keep this wrapper free of own parameters).
        self.orig_embed_ref = [orig_embed]

    @property
    def weight(self):
        """Leading ``n_ctx`` rows of the wrapped embedding's weight."""
        full_table = self.orig_embed_ref[0].weight
        return full_table[:self.n_ctx]

    def forward(self, input_ids):
        """Look ``input_ids`` up in the wrapped (unsliced) embedding."""
        wrapped = self.orig_embed_ref[0]
        return wrapped(input_ids)
27
+
28
def get_sample(example, processor):
    """Turn one raw dataset row into the dict consumed by the training step.

    Args:
        example: Row with ``example["audio"]["array"]``,
            ``example["audio"]["sampling_rate"]`` and ``example["text"]``.
        processor: Callable returning an object with ``input_features`` and
            exposing ``processor.tokenizer.encode``.

    Returns:
        Dict with ``"length"`` (number of samples divided by the sampling
        rate), ``"input_features"`` (the processor output) and
        ``"input_ids"`` (the encoded lower-cased text).
    """
    audio = example["audio"]
    samples = audio["array"]
    rate = audio["sampling_rate"]

    features = processor(samples, sampling_rate=rate, return_tensors="pt").input_features
    clip_seconds = len(samples) / rate
    # The transcript is lower-cased before tokenization.
    token_ids = processor.tokenizer.encode(example["text"].lower())

    return {
        "length": clip_seconds,
        "input_features": features,
        "input_ids": token_ids,
    }
41
+
42
def compute_partially_encoder(model, data, n_audio_ctx):
    """Run ``model.encoder`` on ``data`` with ``n_audio_ctx`` positions.

    Axis 2 of ``data`` is first zero-padded on the right or cropped so that it
    holds exactly ``2 * n_audio_ctx`` entries (presumably because the encoder
    front end halves that axis; confirm against the encoder implementation).

    With ``n_audio_ctx == 1500`` the encoder is called as is. Otherwise the
    encoder's ``embed_positions`` and ``config.max_source_positions`` are
    temporarily replaced for the duration of the call and always restored
    afterwards, even if the forward pass raises.

    Returns:
        ``last_hidden_state`` of the encoder output.
    """
    target_frames = 2 * n_audio_ctx
    surplus = data.shape[2] - target_frames

    if surplus < 0:
        # Too short: append zeros at the end of the last axis.
        data = nn.functional.pad(data, (0, -surplus), mode="constant", value=0.0)
    elif surplus > 0:
        # Too long: keep only the leading frames.
        data = data[:, :, :target_frames]

    encoder = model.encoder
    if n_audio_ctx == 1500:
        return encoder(data).last_hidden_state

    saved_embed = encoder.embed_positions
    saved_max_positions = encoder.config.max_source_positions

    encoder.embed_positions = SlicedEmbedding(saved_embed, n_ctx=n_audio_ctx)
    encoder.config.max_source_positions = n_audio_ctx
    try:
        return encoder(data).last_hidden_state
    finally:
        # Undo the temporary patch so later calls see the original encoder.
        encoder.embed_positions = saved_embed
        encoder.config.max_source_positions = saved_max_positions
66
+
67
def compute_hidden_state_loss(model_train, model_base, criterion, example):
    """Backpropagate the decoder hidden-state mismatch for one example.

    The student (``model_train``) encodes the audio with a reduced, randomly
    jittered number of audio positions, while the teacher (``model_base``)
    always uses the full 1500. Both decoders are run on the same token ids,
    and ``criterion`` compares all of their hidden states, concatenated along
    dim 0.

    Args:
        model_train: Model being trained; gradients are accumulated into it.
        model_base: Reference model, evaluated under ``torch.no_grad()``.
        criterion: Loss callable applied as ``criterion(student, teacher)``.
        example: Dict with ``"length"``, ``"input_features"`` and
            ``"input_ids"``.

    Returns:
        The loss as a Python float. ``loss.backward()`` has already been
        called when this returns.

    Raises:
        ValueError: If ``example["length"]`` maps to a context size outside
            ``1..1500``.
    """
    # 1500 positions cover 30 s, i.e. 50 positions per second of audio.
    n_ctx = int(round((1500.0 / 30.0) * example["length"]))

    # Explicit raise instead of assert so the check survives `python -O`.
    if not 0 < n_ctx <= 1500:
        raise ValueError(f"Invalid n_ctx calculated: {n_ctx}")

    # Random jitter of the context size. torch.randint needs low < high, so
    # clips with n_ctx < 3 (jitter bound 0) are left unjittered, not crashed.
    jitter = min(64, n_ctx // 3)
    if jitter > 0:
        n_ctx += torch.randint(-jitter, jitter, (1,)).item()
    n_ctx = max(1, min(1500, n_ctx))

    input_features = example["input_features"].cuda()
    input_ids = torch.tensor([example["input_ids"]], dtype=torch.long).cuda()

    # Student: truncated audio context, gradients enabled.
    encoder_hidden_states_partial = compute_partially_encoder(model_train, input_features, n_ctx)
    output_partial = model_train.decoder(
        input_ids=input_ids,
        encoder_hidden_states=encoder_hidden_states_partial,
        output_hidden_states=True,
    )

    # Teacher: full audio context, no gradients.
    with torch.no_grad():
        encoder_hidden_states_full = compute_partially_encoder(model_base, input_features, 1500)
        output_full = model_base.decoder(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden_states_full,
            output_hidden_states=True,
        )

    student_tensors = torch.cat(output_partial.hidden_states, 0)
    teacher_tensors = torch.cat(output_full.hidden_states, 0)

    loss = criterion(student_tensors, teacher_tensors)
    loss.backward()

    return loss.item()
102
+
103
def save_checkpoint(model_train, size, processor, output_dir):
    """Write the current weights and the processor to ``output_dir``.

    A fresh ``WhisperForConditionalGeneration`` for ``openai/whisper-{size}``
    is loaded, its ``model`` attribute is replaced by ``model_train``, and the
    result is saved together with ``processor``. ``model_train`` is switched
    to eval mode on CPU for the save, then moved back to CUDA in training mode.
    """
    from transformers import WhisperForConditionalGeneration

    wrapper = WhisperForConditionalGeneration.from_pretrained(f"openai/whisper-{size}")
    wrapper = wrapper.eval().cpu()

    # Swap the trained model in for the freshly loaded one before saving.
    model_train.eval()
    wrapper.model = model_train.cpu()

    wrapper.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)

    # Hand the model back to the training loop in its training state.
    model_train.cuda()
    model_train.train()
110
+
111
def train_futo_script(size):
    """Fine-tune ``openai/whisper-{size}`` to match a frozen copy of itself.

    Streams the ``read_aloud`` train split of ``CoRal-project/coral-v3``,
    skips clips longer than 29 s, and performs one optimizer step per clip
    (batch size 1) using ``compute_hidden_state_loss``. Every
    ``eval_interval`` steps the exponential moving average of the loss is
    compared with the best value so far: an improvement saves a ``*_best``
    checkpoint, otherwise a patience counter is incremented and training
    stops once it reaches ``patience``. Training also stops after
    ``max_training_steps`` steps or on Ctrl+C. A ``*_latest`` checkpoint is
    always written at the end.

    Args:
        size: Whisper size suffix, e.g. ``"tiny"``, ``"base"`` or ``"small"``.
    """
    print(f"Starting exact FUTO distillation for model: {size}")
    param_counts = {"tiny": "39M", "base": "74M", "small": "244M"}
    run_tag = f"{size}_{param_counts.get(size, 'unknown')}_danish_whisper_acft_futo"

    checkpoint_name = f"openai/whisper-{size}"
    model_train = WhisperModel.from_pretrained(checkpoint_name).cuda().train()
    model_base = WhisperModel.from_pretrained(checkpoint_name).cuda().eval()

    # Load the processor from the same checkpoint as the models, so the
    # processor saved next to the weights always belongs to the chosen size.
    processor = WhisperProcessor.from_pretrained(checkpoint_name, language="danish", task="transcribe")

    ds = load_dataset("CoRal-project/coral-v3", "read_aloud", token=HF_TOKEN, split="train", streaming=True)
    ds = ds.cast_column("audio", Audio(sampling_rate=16000))

    criterion = torch.nn.MSELoss()

    # Hyperparameters
    learning_rate = 1e-6
    weight_decay = 0.1
    max_training_steps = 20000

    optimizer = torch.optim.AdamW(model_train.parameters(), lr=learning_rate, weight_decay=weight_decay)

    writer = SummaryWriter()
    writer.add_text("name", f"{size} v3")

    num_length = 0
    step = 0
    running_loss = 0.0

    best_loss = float('inf')
    patience = 20
    patience_counter = 0
    eval_interval = 500

    pbar = tqdm(ds)
    try:
        for raw_example in pbar:
            # The dataset column is cast to 16 kHz above.
            duration = len(raw_example["audio"]["array"]) / 16000.0
            if duration > 29.0:
                continue

            example = get_sample(raw_example, processor)

            optimizer.zero_grad()

            # Compute loss and immediately update (Batch Size 1)
            loss_val = compute_hidden_state_loss(model_train, model_base, criterion, example)
            optimizer.step()

            step += 1
            num_length += example["length"]

            # Update EMA loss
            running_loss = loss_val if step == 1 else 0.95 * running_loss + 0.05 * loss_val

            writer.add_scalar("loss/train", loss_val, step)
            writer.add_scalar("length/train", num_length, step)

            pbar.set_description(f"Step {step}, Avg Loss: {running_loss:.4f}")

            # Checkpoint
            if step % eval_interval == 0:
                if running_loss < best_loss:
                    best_loss = running_loss
                    patience_counter = 0
                    checkpoint_dir = f"{run_tag}_best"
                    save_checkpoint(model_train, size, processor, checkpoint_dir)
                    tqdm.write(f"\n[Step {step}] New best loss: {best_loss:.4f}. Saved checkpoint to {checkpoint_dir}")
                else:
                    patience_counter += 1
                    tqdm.write(f"\n[Step {step}] No improvement. Patience: {patience_counter}/{patience}")
                    if patience_counter >= patience:
                        tqdm.write("\n[Early Stopping] Loss hasn't improved. Halting training.")
                        break

            if step >= max_training_steps:
                tqdm.write("\n[Max Steps Reached] Halting training.")
                break

    except KeyboardInterrupt:
        print("\n\n[CTRL+C detected] Training manually interrupted! Proceeding to save the final model...")
    finally:
        # Flush and release the TensorBoard event file however the loop ended.
        writer.close()

    output_dir = f"{run_tag}_latest"
    print(f"\nSaving latest model to {output_dir}")
    save_checkpoint(model_train, size, processor, output_dir)
195
+
196
if __name__ == "__main__":
    # Command-line entry point: choose the Whisper size and start training.
    cli = argparse.ArgumentParser(description="Run exact FUTO script structure.")
    cli.add_argument("--size", default="base", choices=["tiny", "base", "small"])
    train_futo_script(cli.parse_args().size)