Mayank022 committed on
Commit
b2ac71b
·
verified ·
1 Parent(s): fcaae5c

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +257 -0
train.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import dataclasses
3
+ import torch
4
+ import transformers
5
+ from transformers import Trainer, TrainingArguments, TrainerCallback
6
+ from peft import LoraConfig, get_peft_model, TaskType
7
+ from huggingface_hub import HfApi, login
8
+ import wandb
9
+ from dotenv import load_dotenv
10
+ from config import TrainConfig, ModelConfig
11
+ from model import MultiModalModel
12
+ from data import AudioTextDataset, DataCollator
13
+
14
+
15
class SamplePredictionCallback(TrainerCallback):
    """Every N steps, log ground-truth vs. model-predicted transcripts for a few
    training samples to Weights & Biases (and stdout).

    Hooks ``on_log`` rather than ``on_step_end``, so it can only fire on steps
    where the Trainer logs; ``sample_every_n_steps`` should therefore be a
    multiple of ``logging_steps`` or the callback may never trigger.
    """

    def __init__(self, tokenizer, data_collator, train_dataset,
                 sample_every_n_steps: int = 100, num_samples: int = 2,
                 prompt: str = "Transcribe the following audio:"):
        # Tokenizer is used both to build the generation prompt and to decode
        # labels/predictions back to text.
        self.tokenizer = tokenizer
        self.data_collator = data_collator
        self.train_dataset = train_dataset
        self.sample_every_n_steps = sample_every_n_steps
        self.num_samples = num_samples
        self.prompt = prompt

    def on_log(self, args, state, control, model=None, **kwargs):
        # Only run on non-zero multiples of sample_every_n_steps.
        if state.global_step == 0 or state.global_step % self.sample_every_n_steps != 0:
            return
        if model is None:
            return
        model.eval()
        device = next(model.parameters()).device
        try:
            # Reuse the first few training samples (wrapping if the dataset is
            # smaller than num_samples) so predictions are comparable over time.
            indices = [i % len(self.train_dataset) for i in range(self.num_samples)]
            batch = self.data_collator([self.train_dataset[i] for i in indices])
            audio_values = batch["audio_values"].to(device)
            labels_batch = batch["labels"]
            continuations = batch.get("continuation", [""] * audio_values.size(0))

            prompt_ids = self.tokenizer(
                self.prompt, return_tensors="pt", add_special_tokens=True
            ).input_ids.to(device)
            prompt_ids = prompt_ids.expand(audio_values.size(0), -1)

            # BUGFIX: the previous `pad_token_id or eos_token_id` expression
            # silently fell back to EOS whenever the pad token id was 0
            # (falsy). Test for None explicitly instead.
            pad_id = self.tokenizer.pad_token_id
            if pad_id is None:
                pad_id = self.tokenizer.eos_token_id

            with torch.no_grad():
                gen_ids = model.generate(
                    input_ids=prompt_ids,
                    audio_values=audio_values,
                    max_new_tokens=120,
                    do_sample=False,
                    pad_token_id=pad_id,
                )
            prompt_len = prompt_ids.size(1)

            # Create a wandb Table, one row per sample.
            columns = ["Step", "Audio Index", "Ground Truth", "Prediction", "Continuation"]
            table = wandb.Table(columns=columns)

            print(f"\n[WandB] Logging sample predictions at step {state.global_step}")

            for i in range(audio_values.size(0)):
                # -100 marks positions ignored by the loss; strip them before decoding.
                gt_tokens = [t for t in labels_batch[i].tolist() if t != -100]
                gt_text = self.tokenizer.decode(gt_tokens, skip_special_tokens=True).strip()
                # Decode only the newly generated tokens (everything after the prompt).
                pred_text = self.tokenizer.decode(gen_ids[i][prompt_len:], skip_special_tokens=True).strip()

                cont_ref = continuations[i] if i < len(continuations) else ""

                # Add row to table
                table.add_data(state.global_step, i, gt_text, pred_text, cont_ref)

            # Log the table to wandb
            if wandb.run is not None:
                wandb.log({"sample_predictions": table}, step=state.global_step)
            else:
                print("Warning: WandB run not active, skipping logging.")

        except Exception as e:
            # Best-effort logging: sampling must never crash the training run.
            print(f"[SamplePredictionCallback] Error: {e}\n")
        finally:
            # Restore training mode even if generation failed.
            model.train()
76
+
77
+
78
+ import shutil
79
+ import glob
80
+ from transformers.trainer_utils import get_last_checkpoint
81
+
82
class AggressiveDeleteCallback(TrainerCallback):
    """Remove ALL existing checkpoints in ``output_dir`` *before* a new save
    to avoid running out of disk space.

    Just before each save there are zero checkpoints on disk; effectively only
    the in-memory state being trained is kept.

    WARNING: If save fails, we have NO checkpoints on disk. Risk accepted by user.
    """

    def __init__(self, output_dir):
        self.output_dir = output_dir

    def on_step_end(self, args, state, control, **kwargs):
        # Mirror the Trainer's own save condition: step-based saving enabled
        # and the current (non-zero) global step is a multiple of save_steps.
        step_saving = args.save_strategy == "steps" and args.save_steps > 0
        if not step_saving:
            return
        step = state.global_step
        if step <= 0 or step % args.save_steps != 0:
            return

        # A save is imminent — clear every checkpoint directory first. The
        # actual save happens after on_step_end, so nothing freshly written is
        # deleted here; a checkpoint we resumed from, however, will be removed.
        print(f"\n[AggressiveDeleteCallback] Step {step}: Deleting old checkpoints to free space before saving...")
        for ckpt in glob.glob(os.path.join(self.output_dir, "checkpoint-*")):
            try:
                shutil.rmtree(ckpt)
            except Exception as e:
                print(f" Failed to delete {ckpt}: {e}")
            else:
                print(f" Deleted {ckpt}")
112
+
113
def train():
    """End-to-end training entry point.

    Orchestrates, in order: environment loading, config construction, W&B
    initialization, tokenizer/processor/model setup, optional LoRA wrapping of
    the LLM, dataset/collator creation, HF Trainer setup with custom
    callbacks, (resumed) training, local save, and an optional push to the
    Hugging Face Hub. Statement order matters: ``wandb.init()`` runs before
    the Trainer is built (``report_to="wandb"``), and ``login()`` runs before
    any ``HfApi`` calls.
    """
    # Load environment variables
    load_dotenv()

    # Load Configs
    train_config = TrainConfig()
    model_config = ModelConfig()

    # Initialize WandB
    wandb.init(
        project=train_config.wandb_project,
        entity=train_config.wandb_entity,
        name=train_config.wandb_run_name,
        config=dataclasses.asdict(train_config),
    )


    # Initialize Tokenizer & Processor
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_config.text_model_id)
    if tokenizer.pad_token is None:
        # Some causal-LM tokenizers ship without a pad token; reuse EOS so
        # batched padding works.
        tokenizer.pad_token = tokenizer.eos_token

    processor = transformers.AutoProcessor.from_pretrained(model_config.audio_model_id)

    # Initialize Model
    model = MultiModalModel(model_config)

    # Apply LoRA if requested. Adapters are attached to model.llm only; the
    # rest of the multimodal model is left untouched.
    if train_config.use_lora:
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=train_config.lora_r,
            lora_alpha=train_config.lora_alpha,
            lora_dropout=train_config.lora_dropout,
            target_modules=["q_proj", "v_proj"]
        )
        model.llm = get_peft_model(model.llm, peft_config)
        model.llm.print_trainable_parameters()

    # Dataset
    train_dataset = AudioTextDataset(train_config, processor, model_config, tokenizer)
    data_collator = DataCollator(processor, tokenizer)

    # Training Arguments (tuned for A100 80GB: bf16, larger batch, fast dataloader)
    # NOTE(review): `eval_strategy` (vs. the older `evaluation_strategy`)
    # requires a recent transformers release — confirm against the pinned version.
    training_args = TrainingArguments(
        output_dir=train_config.output_dir,
        per_device_train_batch_size=train_config.batch_size,
        gradient_accumulation_steps=train_config.accum_steps,
        learning_rate=train_config.learning_rate,
        lr_scheduler_type=train_config.lr_scheduler_type,
        num_train_epochs=train_config.num_epochs,
        max_steps=train_config.max_steps,
        bf16=train_config.use_bf16,
        gradient_checkpointing=train_config.gradient_checkpointing,
        dataloader_num_workers=train_config.dataloader_num_workers,
        dataloader_pin_memory=train_config.dataloader_pin_memory,
        logging_steps=train_config.log_steps,
        logging_first_step=True,
        logging_nan_inf_filter=True,
        save_steps=train_config.save_steps,
        save_total_limit=train_config.save_total_limit,
        eval_strategy="no", # change if val set provided
        remove_unused_columns=False, # Important because we have custom forward signature
        report_to="wandb",
        log_level="info",
        log_level_replica="info",
    )

    # Periodically logs ground-truth vs. predicted transcripts to W&B.
    # NOTE(review): the callback fires from on_log, so sample_pred_every_steps
    # should be a multiple of log_steps or it may never trigger — verify configs.
    sample_callback = SamplePredictionCallback(
        tokenizer=tokenizer,
        data_collator=data_collator,
        train_dataset=train_dataset,
        sample_every_n_steps=train_config.sample_pred_every_steps,
        num_samples=2,
        prompt="Transcribe the following audio:",
    )

    # Frees disk by removing *all* on-disk checkpoints right before each save
    # (accepted risk: a failed save leaves no checkpoint on disk).
    aggressive_delete_callback = AggressiveDeleteCallback(train_config.output_dir)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
        callbacks=[sample_callback, aggressive_delete_callback],
    )

    total_steps = train_config.max_steps
    print(f"\n>>> Training: max_steps={total_steps}, batch_size={train_config.batch_size}, "
          f"grad_accum={train_config.accum_steps} (effective batch={train_config.batch_size * train_config.accum_steps})")
    print(f">>> Sample predictions (GT vs predicted transcript) every {train_config.sample_pred_every_steps} steps.\n")

    # Resume from checkpoint if exists
    last_checkpoint = get_last_checkpoint(train_config.output_dir)
    if last_checkpoint is not None:
        print(f">>> Resuming from checkpoint: {last_checkpoint}")
        trainer.train(resume_from_checkpoint=last_checkpoint)
    else:
        trainer.train()

    # Save the final model together with the tokenizer and audio processor so
    # output_dir is self-contained for inference.
    trainer.save_model(train_config.output_dir)
    tokenizer.save_pretrained(train_config.output_dir)
    processor.save_pretrained(train_config.output_dir)

    # Push to Hub
    if train_config.push_to_hub:
        print(f"\n>>> Pushing model to Hugging Face Hub: {train_config.hub_model_id}")
        if train_config.hub_token:
            login(token=train_config.hub_token)

        api = HfApi()

        # Create repo if needed
        # private=True by default for safety, user can adjust
        try:
            api.create_repo(repo_id=train_config.hub_model_id, private=train_config.hub_private_repo, exist_ok=True)
        except Exception as e:
            print(f"Warning: Could not create repo {train_config.hub_model_id}. Error: {e}")

        # Upload model folder
        try:
            api.upload_folder(
                folder_path=train_config.output_dir,
                repo_id=train_config.hub_model_id,
                repo_type="model",
            )

            # Upload code files to ensure custom model works
            for file in ["model.py", "config.py", "data.py", "inference.py"]:
                if os.path.exists(file):
                    api.upload_file(
                        path_or_fileobj=file,
                        path_in_repo=file,
                        repo_id=train_config.hub_model_id,
                        repo_type="model",
                    )

            print(f">>> Successfully pushed to {train_config.hub_model_id}")
        except Exception as e:
            print(f"Error pushing to hub: {e}")
255
+
256
+ if __name__ == "__main__":
257
+ train()