Javad Taghia committed on
Commit
2abe5d0
·
1 Parent(s): 949500b

Add W&B summary metrics (dataset/model config, precision mode, token estimate) and training-duration logging to train_tulu.py

Browse files
Files changed (2) hide show
  1. README.md +1 -0
  2. train_tulu.py +35 -0
README.md CHANGED
@@ -60,6 +60,7 @@ Key flags:
60
  - `train_tulu.py` loads `.env`, logs into W&B, and reports through `Trainer(report_to=["wandb"])`.
61
  - Ensure `WANDB_API_KEY`, `WANDB_PROJECT`, and (optionally) `WANDB_ENTITY` are set in `.env`.
62
  - Each run captures hyperparameters and metrics; check the W&B UI for live loss curves and checkpoints.
 
63
 
64
  ## Model cache location
65
  - Base model weights download to the Hugging Face cache. You can point downloads to an external directory by setting `BASE_MODEL_CACHE` in `.env` (e.g., `/Volumes/JTQ-s/______GITLAB____/downloaded_base_models`); the script maps this to `HF_HOME`/`TRANSFORMERS_CACHE` before loading models.
 
60
  - `train_tulu.py` loads `.env`, logs into W&B, and reports through `Trainer(report_to=["wandb"])`.
61
  - Ensure `WANDB_API_KEY`, `WANDB_PROJECT`, and (optionally) `WANDB_ENTITY` are set in `.env`.
62
  - Each run captures hyperparameters and metrics; check the W&B UI for live loss curves and checkpoints.
63
+ - Additional summaries are logged: `train_duration_seconds`, `train_examples`, `estimated_tokens`, `precision_mode` (bf16/fp16/fp32), `use_4bit`, `model_name`, `dataset_name`, `per_device_batch_size`, `gradient_accumulation_steps`, and `max_seq_length`.
64
 
65
  ## Model cache location
66
  - Base model weights download to the Hugging Face cache. You can point downloads to an external directory by setting `BASE_MODEL_CACHE` in `.env` (e.g., `/Volumes/JTQ-s/______GITLAB____/downloaded_base_models`); the script maps this to `HF_HOME`/`TRANSFORMERS_CACHE` before loading models.
train_tulu.py CHANGED
@@ -9,6 +9,7 @@ from __future__ import annotations
9
 
10
  import argparse
11
  import os
 
12
  from dataclasses import dataclass
13
  from typing import Dict, List
14
 
@@ -138,22 +139,47 @@ def configure_cache_from_env():
138
 
139
  def main():
140
  load_dotenv()
 
141
  configure_cache_from_env()
 
142
  cfg = parse_args()
 
143
 
144
  init_wandb(cfg)
 
145
  model, tokenizer = load_model_and_tokenizer(cfg)
 
146
 
147
  use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
148
  use_fp16 = torch.cuda.is_available() and not use_bf16
 
 
149
 
150
  raw_dataset = load_dataset(cfg.dataset_name)
 
151
  tokenized = raw_dataset["train"].map(
152
  lambda ex: tokenize_example(ex, tokenizer, cfg.max_seq_length),
153
  remove_columns=raw_dataset["train"].column_names,
154
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
157
 
158
  training_args = TrainingArguments(
159
  output_dir=cfg.output_dir,
@@ -169,6 +195,7 @@ def main():
169
  report_to=["wandb"],
170
  optim="paged_adamw_32bit",
171
  )
 
172
 
173
  trainer = Trainer(
174
  model=model,
@@ -177,11 +204,19 @@ def main():
177
  tokenizer=tokenizer,
178
  data_collator=data_collator,
179
  )
 
180
 
 
181
  trainer.train()
 
 
 
 
182
  trainer.save_model(cfg.output_dir)
183
  tokenizer.save_pretrained(cfg.output_dir)
 
184
  wandb.finish()
 
185
 
186
 
187
  if __name__ == "__main__":
 
9
 
10
  import argparse
11
  import os
12
+ import time
13
  from dataclasses import dataclass
14
  from typing import Dict, List
15
 
 
139
 
140
  def main():
141
  load_dotenv()
142
+ # Load env vars (WANDB keys, optional cache path).
143
  configure_cache_from_env()
144
+ # Redirect HF cache if BASE_MODEL_CACHE is set.
145
  cfg = parse_args()
146
+ # Read CLI hyperparameters/model settings.
147
 
148
  init_wandb(cfg)
149
+ # Start a W&B run with config and login.
150
  model, tokenizer = load_model_and_tokenizer(cfg)
151
+ # Load base model + tokenizer with LoRA (and 4-bit if enabled).
152
 
153
  use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
154
  use_fp16 = torch.cuda.is_available() and not use_bf16
155
+ # Choose best available mixed precision (bf16 > fp16 > fp32).
156
+ precision_mode = "bf16" if use_bf16 else "fp16" if use_fp16 else "fp32"
157
 
158
  raw_dataset = load_dataset(cfg.dataset_name)
159
+ # Download/load the instruction dataset.
160
  tokenized = raw_dataset["train"].map(
161
  lambda ex: tokenize_example(ex, tokenizer, cfg.max_seq_length),
162
  remove_columns=raw_dataset["train"].column_names,
163
  )
164
+ # Format/tokenize dataset to fixed length with labels.
165
+ train_examples = len(tokenized)
166
+ total_tokens = train_examples * cfg.max_seq_length
167
+ wandb.summary.update(
168
+ {
169
+ "train_examples": train_examples,
170
+ "estimated_tokens": total_tokens,
171
+ "precision_mode": precision_mode,
172
+ "use_4bit": cfg.use_4bit,
173
+ "model_name": cfg.model_name,
174
+ "dataset_name": cfg.dataset_name,
175
+ "per_device_batch_size": cfg.per_device_batch_size,
176
+ "gradient_accumulation_steps": cfg.gradient_accumulation_steps,
177
+ "max_seq_length": cfg.max_seq_length,
178
+ }
179
+ )
180
 
181
  data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
182
+ # Pad/batch causal LM examples.
183
 
184
  training_args = TrainingArguments(
185
  output_dir=cfg.output_dir,
 
195
  report_to=["wandb"],
196
  optim="paged_adamw_32bit",
197
  )
198
+ # Trainer configuration (logging, saving, optimizer, precision).
199
 
200
  trainer = Trainer(
201
  model=model,
 
204
  tokenizer=tokenizer,
205
  data_collator=data_collator,
206
  )
207
+ # Wire model, data, and config into HF Trainer.
208
 
209
+ train_start = time.time()
210
  trainer.train()
211
+ # Run supervised finetuning (cross-entropy).
212
+ train_duration = time.time() - train_start
213
+ wandb.log({"train_duration_seconds": train_duration})
214
+ # Record wall-clock training time to W&B.
215
  trainer.save_model(cfg.output_dir)
216
  tokenizer.save_pretrained(cfg.output_dir)
217
+ # Save adapters/tokenizer to output_dir.
218
  wandb.finish()
219
+ # Close W&B run.
220
 
221
 
222
  if __name__ == "__main__":