craffel/moto_checkpoints / script_1 /code /apps /main /train_answer_only.py
craffel's picture
download
raw
33.6 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
import gc
import json
import logging
import math
import os
from contextlib import ExitStack
from dataclasses import dataclass, field
from pathlib import Path
from timeit import default_timer as timer
from typing import Any, Dict, Optional
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from omegaconf import OmegaConf
from torch.distributed._tensor import DTensor
from apps.main.train import TrainState, every_n_steps, set_preemption_flag
from apps.main.transformer import (
LMTransformer,
LMTransformerArgs,
build_fsdp_grouping_plan,
get_no_recompute_ops,
get_num_flop_per_token,
tp_parallelize,
)
from lingua.args import dataclass_from_dict, dump_config, flatten_dict
from lingua.checkpoint import (
CheckpointArgs,
CheckpointManager,
consolidate_checkpoints,
load_from_checkpoint,
)
from lingua.data import TRAIN_DATA_FILE_PATTERN, distribute_data_to_rank, loop_on_jsonl
from lingua.distributed import (
DistributedArgs,
EnvironmentArgs,
check_model_value_range,
dist_mean_dict,
get_device_mesh,
get_global_rank,
get_is_master,
get_world_size,
init_signal_handler,
parallelize_model,
setup_env,
setup_torch_distributed,
)
from lingua.logger import init_logger
from lingua.metrics import LoggingArgs, MetricLogger, get_num_params
from lingua.optim import OptimArgs, build_optimizer
from lingua.profiling import ProfilerArgs
from lingua.tokenizer import TokenizerArgs, build_token_bytes, build_tokenizer
# Root logger; lingua's init_logger() configures its handlers in train().
logger = logging.getLogger()

# Debugging switch: only the exact string "True" enables dumping of decoded
# training samples (see maybe_dump_training_batch).
DUMP_DOCS = os.getenv("DUMP_DOCS", "False") == "True"
# Number of samples per batch written when DUMP_DOCS is enabled.
DUMP_DOCS_MAX_SAMPLES = int(os.getenv("DUMP_DOCS_MAX_SAMPLES", "2"))
@dataclass
class QADataArgs:
    """Data configuration for answer-only (prompt-masked) training.

    Rows are read from JSONL shards under ``<root_dir>/<source>`` for each key
    of ``sources``. Each example's source is sampled with probability
    proportional to its weight.
    """

    root_dir: Optional[str] = None  # parent directory of all source dirs; train() requires it
    sources: Dict[str, float] = field(default_factory=dict)  # source dir name -> sampling weight (normalized at runtime)
    batch_size: int = 64  # examples per micro-batch on each rank
    seq_len: int = 64  # every example is truncated / right-padded to this many input tokens
    seed: int = 42  # source-sampling seed, combined with global rank and world size
    add_bos: bool = True  # passed to tokenizer.encode for both full text and prompt
    add_eos: bool = True  # passed to tokenizer.encode for the full text only (prompt never gets EOS)
    tokenizer: TokenizerArgs = field(default_factory=TokenizerArgs)  # forwarded to build_tokenizer in train()
    file_pattern: str = TRAIN_DATA_FILE_PATTERN  # shard glob inside each source dir (given to distribute_data_to_rank)
    text_key: str = "text"  # full-text field; "content" is tried when it is missing
    question_key: str = "question"  # prompt field; tokens predicted inside the prompt are masked from the loss
    answer_key: str = "answer"  # answer field; with question_key present, the sequence is question + answer
    n_views: int = 2  # only stored in the placeholder data_loader_state in train()
    prefetch_size: int = 64  # NOTE(review): not read anywhere in this file
    load_async: bool = True  # NOTE(review): not read anywhere in this file
    suitable_tokenizer_key: str = "suitable_tokenizer"  # row field naming a preferred sub-tokenizer
    suitable_tokenizer_probability: float = 0.0  # passed as preferred_probability to tokenizer.sample_tokenizer
    suitable_tokenizer_map: Dict[str, str] = field(default_factory=dict)  # dataset tokenizer name -> tokenizer key aliases
@dataclass
class TrainAnswerOnlyArgs:
    """Top-level configuration for ``train()``.

    Built by ``main()`` by merging these defaults with a YAML file and
    command-line overrides.
    """

    name: str = "lingua-answer-only"  # job name; logged at startup and used as the wandb run name
    dump_dir: str = ""  # required; receives config.yaml, train.log, metrics.jsonl (and checkpoints by default)
    seed: int = 42  # torch.manual_seed applied before building the model
    grad_acc_steps: int = 1  # micro-batches accumulated per optimizer step
    steps: int = 1000  # number of optimizer steps; also given to build_optimizer for the LR schedule
    data: QADataArgs = field(default_factory=QADataArgs)
    optim: OptimArgs = field(default_factory=OptimArgs)
    model: LMTransformerArgs = field(default_factory=LMTransformerArgs)
    distributed: DistributedArgs = field(default_factory=DistributedArgs)
    env: EnvironmentArgs = field(default_factory=EnvironmentArgs)
    checkpoint: CheckpointArgs = field(default_factory=CheckpointArgs)
    profiling: ProfilerArgs = field(default_factory=ProfilerArgs)
    logging: LoggingArgs = field(default_factory=LoggingArgs)

    track_source_metrics: bool = False  # also log loss / accuracy separately for each data source
    gc_collect_freq: int = 1000  # steps between manual gc.collect() calls (automatic GC is disabled)
    probe_freq: Optional[int] = None  # only validated in validate_train_args; no probing happens in this file
    async_eval_gpus: Optional[int] = None  # NOTE(review): not read anywhere in this file
    eval: Optional[Any] = None  # NOTE(review): not read anywhere in this file
def _source_iterators(args: TrainAnswerOnlyArgs):
    """Open one JSONL row iterator per configured data source.

    The shard files of each ``<data.root_dir>/<source>`` directory are split
    across workers by ``distribute_data_to_rank`` (keyed on the global rank and
    world size). The resulting read state is handed to ``loop_on_jsonl``, which
    presumably cycles over the shard indefinitely (confirm in lingua.data).

    Returns:
        dict mapping each source name to its iterator. ``_batch_iterator``
        consumes ``(row, state)`` pairs from it via ``next()``.
    """
    rank, world_size = get_global_rank(), get_world_size()
    data_cfg = args.data

    def _open(source_name):
        # Resolve this rank's slice of the source, then start reading from it.
        shard = distribute_data_to_rank(
            os.path.join(data_cfg.root_dir, source_name),
            rank=rank,
            world_size=world_size,
            file_pattern=data_cfg.file_pattern,
        )
        return loop_on_jsonl(
            shard["file_path"],
            shard["position"],
            shard["block_size"],
            shard["offset"],
            shard["current_iter"],
        )

    return {source_name: _open(source_name) for source_name in data_cfg.sources}
def _normalize_weights(sources: Dict[str, float]) -> np.ndarray:
weights = np.array([float(v) for v in sources.values()], dtype=np.float64)
weights = weights / weights.sum()
return weights
def _to_list(token_ids):
if isinstance(token_ids, np.ndarray):
return token_ids.tolist()
return list(token_ids)
def _build_example(
    row: Dict[str, Any],
    tokenizer,
    seq_len: int,
    add_bos: bool,
    add_eos: bool,
    text_key: str,
    question_key: str,
    answer_key: str,
    tokenizer_choice: Optional[int] = None,
):
    """Tokenize one JSONL row into a fixed-length ``(input_ids, labels)`` pair.

    The sequence and its prompt are chosen as follows:

    1. If the row has both ``question_key`` and ``answer_key``, the sequence is
       ``str(question) + str(answer)`` and the question is the prompt. No text
       field is needed in this case.
    2. Otherwise the sequence is the ``text_key`` field, falling back to
       ``"content"``. The prompt is the row's question if present. If not, it is
       everything up to and including the last space, so only the final
       space-separated word is supervised. A text without spaces is its own
       prompt.

    Labels are the inputs shifted left by one. Positions whose target still
    lies inside the prompt are set to -100 (ignored by the loss), so only
    answer tokens (plus EOS when ``add_eos``) are trained on. This assumes the
    prompt tokenizes to a prefix of the full sequence's tokens.

    Args:
        tokenizer: object with ``encode(text, add_bos=..., add_eos=...[,
            tokenizer_choice=...])``. Its optional ``pad_id`` / ``eos_id``
            attributes select the padding token.
        tokenizer_choice: forwarded to ``encode`` only when not ``None``.

    Returns:
        ``(input_ids, labels)``, two lists of exactly ``seq_len`` ints. They are
        right-padded, with labels padded with -100. Returns ``None`` if the row
        has no usable text or tokenizes to fewer than two tokens.
    """
    question = row.get(question_key)
    answer = row.get(answer_key)
    if question is not None and answer is not None:
        # QA rows are self-contained: the text field is neither needed nor used.
        prompt_text = str(question)
        answer_text = str(answer)
        full_text = f"{prompt_text}{answer_text}"
    else:
        full_text = row.get(text_key)
        if full_text is None:
            full_text = row.get("content")
        if full_text is None:
            return None
        if question is None:
            # No explicit prompt: supervise only the last space-separated word.
            parts = str(full_text).rsplit(" ", 1)
            question = (parts[0] + " ") if len(parts) > 1 else str(full_text)
        prompt_text = str(question)

    encode_kwargs: Dict[str, Any] = {"add_bos": add_bos, "add_eos": add_eos}
    # The prompt never gets EOS so that its tokens stay a prefix of the full sequence.
    prompt_encode_kwargs: Dict[str, Any] = {"add_bos": add_bos, "add_eos": False}
    if tokenizer_choice is not None:
        encode_kwargs["tokenizer_choice"] = tokenizer_choice
        prompt_encode_kwargs["tokenizer_choice"] = tokenizer_choice

    full_ids = _to_list(tokenizer.encode(str(full_text), **encode_kwargs))
    prompt_ids = _to_list(tokenizer.encode(prompt_text, **prompt_encode_kwargs))

    # One extra token so that inputs/labels are both seq_len long after shifting.
    full_ids = full_ids[: seq_len + 1]
    if len(full_ids) < 2:
        return None
    input_ids = full_ids[:-1]
    labels = full_ids[1:]

    # labels[i] is token i + 1; it belongs to the prompt while i + 1 < len(prompt_ids).
    prompt_target_count = max(0, min(len(labels), len(prompt_ids) - 1))
    labels[:prompt_target_count] = [-100] * prompt_target_count

    # Padding sits after the real tokens and its labels are -100, so its value
    # only has to be a valid (non-negative) embedding index. Some tokenizers
    # report pad_id == -1 (e.g. SentencePiece without a pad piece), which
    # would be out of range; fall back to EOS, then 0.
    pad_id = getattr(tokenizer, "pad_id", None)
    if pad_id is None or int(pad_id) < 0:
        pad_id = getattr(tokenizer, "eos_id", None)
    pad_id = 0 if pad_id is None or int(pad_id) < 0 else int(pad_id)

    pad_n = seq_len - len(input_ids)
    if pad_n > 0:
        input_ids = input_ids + [pad_id] * pad_n
        labels = labels + [-100] * pad_n
    else:
        input_ids = input_ids[:seq_len]
        labels = labels[:seq_len]
    return input_ids, labels
def _sample_row_tokenizer_choice(
row: Dict[str, Any],
tokenizer,
suitable_tokenizer_key: str,
suitable_tokenizer_probability: float,
suitable_tokenizer_map: Dict[str, str],
) -> Optional[int]:
if not hasattr(tokenizer, "sample_tokenizer"):
return None
preferred_tokenizer = _map_dataset_tokenizer_to_superset_key(
dataset_tokenizer_name=row.get(suitable_tokenizer_key),
tokenizer=tokenizer,
suitable_tokenizer_map=suitable_tokenizer_map,
)
try:
if preferred_tokenizer == "random":
sampled_choice, _ = tokenizer.sample_tokenizer()
else:
sampled_choice, _ = tokenizer.sample_tokenizer(
preferred_tokenizer=preferred_tokenizer,
preferred_probability=suitable_tokenizer_probability,
)
return int(sampled_choice)
except TypeError:
sampled_choice, _ = tokenizer.sample_tokenizer()
return int(sampled_choice)
except Exception:
return None
def _map_dataset_tokenizer_to_superset_key(
dataset_tokenizer_name: Any,
tokenizer,
suitable_tokenizer_map: Dict[str, str],
) -> Optional[str]:
if dataset_tokenizer_name is None:
return None
raw_name = str(dataset_tokenizer_name).strip()
if raw_name == "":
return None
lowered_name = raw_name.lower()
mapped = suitable_tokenizer_map.get(raw_name)
if mapped is None:
mapped = suitable_tokenizer_map.get(lowered_name)
candidate_name = mapped if mapped is not None else raw_name
if not hasattr(tokenizer, "tokenizers") or not isinstance(tokenizer.tokenizers, dict):
return candidate_name
tokenizer_keys = list(tokenizer.tokenizers.keys())
if candidate_name in tokenizer.tokenizers:
return candidate_name
lowered_candidate = candidate_name.lower()
for key in tokenizer_keys:
lowered_key = key.lower()
if lowered_key == lowered_candidate:
return key
for key in tokenizer_keys:
lowered_key = key.lower()
if lowered_key.endswith(f"/{lowered_candidate}"):
return key
if lowered_candidate.endswith(f"/{lowered_key}"):
return key
return candidate_name
def _batch_iterator(args: TrainAnswerOnlyArgs, tokenizer):
    """Yield an endless stream of ``(input_ids, labels, source_ids)`` batches.

    For every example slot:

    1. A source is sampled according to the normalized ``data.sources``
       weights.
    2. That source's next JSONL row is pulled.
    3. A sub-tokenizer is chosen for the row.
    4. The row is tokenized with ``_build_example``.

    Rows rejected by ``_build_example`` are skipped, but the sampled row is
    still consumed. The sampler is seeded from ``(data.seed, global rank,
    world size)``, so every rank draws its own source sequence.

    Yields:
        Three ``torch.long`` tensors:

        * inputs and labels, each ``(data.batch_size, data.seq_len)``;
        * the per-example source index (position in ``data.sources``),
          ``(data.batch_size,)``.

    All per-source iterators are closed when this generator is closed or
    finalized.
    """
    data_cfg = args.data
    names = list(data_cfg.sources.keys())
    probabilities = _normalize_weights(data_cfg.sources)
    index_of = {name: position for position, name in enumerate(names)}
    streams = _source_iterators(args)
    rng = np.random.default_rng((data_cfg.seed, get_global_rank(), get_world_size()))

    def _draw():
        # Keep sampling (source, row) pairs until one tokenizes into an example.
        while True:
            picked = names[rng.choice(len(names), p=probabilities)]
            row, _ = next(streams[picked])
            choice = _sample_row_tokenizer_choice(
                row=row,
                tokenizer=tokenizer,
                suitable_tokenizer_key=data_cfg.suitable_tokenizer_key,
                suitable_tokenizer_probability=data_cfg.suitable_tokenizer_probability,
                suitable_tokenizer_map=data_cfg.suitable_tokenizer_map,
            )
            example = _build_example(
                row=row,
                tokenizer=tokenizer,
                seq_len=data_cfg.seq_len,
                add_bos=data_cfg.add_bos,
                add_eos=data_cfg.add_eos,
                text_key=data_cfg.text_key,
                question_key=data_cfg.question_key,
                answer_key=data_cfg.answer_key,
                tokenizer_choice=choice,
            )
            if example is not None:
                return picked, example

    try:
        while True:
            inputs, targets, origins = [], [], []
            while len(inputs) < data_cfg.batch_size:
                picked, (x, y) = _draw()
                inputs.append(x)
                targets.append(y)
                origins.append(index_of[picked])
            yield (
                torch.tensor(inputs, dtype=torch.long),
                torch.tensor(targets, dtype=torch.long),
                torch.tensor(origins, dtype=torch.long),
            )
    finally:
        for stream in streams.values():
            stream.close()
def _safe_decode(tokenizer, token_ids: list[int]) -> Optional[str]:
    """Decode ``token_ids`` for debug output without ever raising.

    Tries ``decode(ids, skip_special_tokens=False)`` first. A TypeError (a
    ``decode`` without that keyword) triggers a retry with the plain call. Any
    other failure, or a failed retry, yields ``None``.
    """
    try:
        return tokenizer.decode(token_ids, skip_special_tokens=False)
    except TypeError:
        try:
            return tokenizer.decode(token_ids)
        except Exception:
            return None
    except Exception:
        return None


def maybe_dump_training_batch(
    tokenizer,
    input_ids: torch.Tensor,
    labels: torch.Tensor,
    source_ids: torch.Tensor,
    source_names: list[str],
    step: int,
    acc_step: int,
    dump_dir: str,
    strip_token_id: Optional[int] = 2,
) -> None:
    """Append a few decoded samples of a batch to a per-rank JSONL file.

    This is a no-op unless the ``DUMP_DOCS`` env switch is on and ``dump_dir``
    is non-empty. Otherwise it writes the first ``DUMP_DOCS_MAX_SAMPLES`` rows
    to ``<dump_dir>/training_docs/rank_<global rank>.jsonl``. Each line holds:

    * step, acc_step, sample_idx;
    * source id and source name;
    * the kept input ids and label ids;
    * their decoded texts (``None`` when decoding fails).

    Args:
        strip_token_id: token id removed from both id lists before decoding.
            Label positions equal to -100 are always removed. The default 2
            keeps the historical behaviour; it is presumably the pad/EOS id of
            the tokenizer this was written for (confirm). ``None`` disables
            stripping.
    """
    if not DUMP_DOCS or not dump_dir:
        return

    rank = get_global_rank()
    dump_path = Path(dump_dir) / "training_docs" / f"rank_{rank}.jsonl"
    dump_path.parent.mkdir(parents=True, exist_ok=True)

    cpu_input_ids = input_ids.detach().cpu()
    cpu_labels = labels.detach().cpu()
    cpu_source_ids = source_ids.detach().cpu()
    n_samples = min(DUMP_DOCS_MAX_SAMPLES, cpu_input_ids.shape[0])

    with open(dump_path, "a") as f_dump:
        for sample_idx in range(n_samples):
            input_id_list = [
                x for x in cpu_input_ids[sample_idx].tolist() if x != strip_token_id
            ]
            label_id_list = [
                x
                for x in cpu_labels[sample_idx].tolist()
                if x != strip_token_id and x != -100
            ]
            source_id = int(cpu_source_ids[sample_idx].item())
            row = {
                "step": int(step),
                "acc_step": int(acc_step),
                "sample_idx": int(sample_idx),
                "source_id": source_id,
                "source_name": source_names[source_id] if 0 <= source_id < len(source_names) else None,
                "input_ids": input_id_list,
                "label_ids": label_id_list,
                "input_text": _safe_decode(tokenizer, input_id_list),
                "label_text": _safe_decode(tokenizer, label_id_list),
            }
            f_dump.write(json.dumps(row) + "\n")
def validate_train_args(args: TrainAnswerOnlyArgs, output_size: int):
    """Check ``args`` for consistency and fill in derived settings, in place.

    Mutations:
      * ``model.vocab_size`` is set to ``output_size`` when negative.
      * ``checkpoint.path`` defaults to ``<dump_dir>/checkpoints``.
      * ``distributed.dp_replicate`` is recomputed when dp_replicate * dp_shard *
        tp_size does not equal the world size.
      * ``model.max_seqlen`` is set to ``data.seq_len``.
      * ``logging.wandb.name`` is set to ``args.name`` when wandb is configured.

    Args:
        args: the full training configuration.
        output_size: the tokenizer vocabulary size the model must output.

    Raises:
        AssertionError: if any of the following holds:
            * the vocab size does not match ``output_size``;
            * ``dump_dir`` is not set, or a data directory is missing;
            * ``data.root_dir`` is unset while sources are configured;
            * ``grad_acc_steps`` is smaller than 1;
            * the parallel layout is invalid;
            * probing is combined with profiling, tensor parallelism or
              selective activation checkpointing.
    """
    if args.model.vocab_size < 0:
        logger.info(f"Setting model output size to {output_size}")
        args.model.vocab_size = output_size
    assert (
        args.model.vocab_size == output_size
    ), "Vocab size should be the same as output size"

    assert args.dump_dir, "Dump dir not set"
    # train() computes `acc_step % grad_acc_steps`: zero raises ZeroDivisionError
    # and a negative value breaks the accumulation schedule.
    assert args.grad_acc_steps >= 1, "grad_acc_steps must be >= 1"

    if args.checkpoint.path is None:
        logger.info(f"Setting checkpoint path to {str(Path(args.dump_dir) / 'checkpoints')}")
        args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints")

    if args.data.sources:
        # Otherwise os.path.join(None, ...) below fails with an opaque TypeError.
        assert args.data.root_dir is not None, "data.root_dir is required"
    for source in args.data.sources:
        data_path = os.path.join(args.data.root_dir, source)
        assert os.path.exists(data_path), f"{data_path} doesn't exist"

    # If the configured degrees don't cover the world, keep dp_shard and
    # tp_size fixed and derive dp_replicate from the world size.
    if (
        args.distributed.dp_replicate
        * args.distributed.dp_shard
        * args.distributed.tp_size
        != get_world_size()
    ):
        assert get_world_size() % args.distributed.dp_shard == 0
        args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard

        assert args.distributed.dp_replicate % args.distributed.tp_size == 0
        args.distributed.dp_replicate = (
            args.distributed.dp_replicate // args.distributed.tp_size
        )

        logger.warning(
            f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}"
        )
        assert (
            args.distributed.dp_replicate
            * args.distributed.dp_shard
            * args.distributed.tp_size
            == get_world_size()
        )

        if args.distributed.fsdp_type == "no_shard":
            assert (
                args.distributed.dp_shard == 1
                and args.distributed.dp_replicate == get_world_size()
            )

    args.model.max_seqlen = args.data.seq_len

    # The warning concerns *using* tensor parallelism, so it must fire only when
    # TP is enabled (the previous `== 1` check warned exactly when it was off).
    if args.distributed.tp_size > 1:
        logger.warning(
            "Tensor parallelism has not been tested for a while, use at your own risk"
        )

    assert (
        args.probe_freq != args.profiling.mem_steps
    ), "Don't profile during probe step"
    assert (
        args.probe_freq != args.profiling.profile_steps
    ), "Don't profile during probe step"
    if args.logging.wandb is not None:
        args.logging.wandb.name = args.name

    if args.probe_freq is not None:
        assert (
            args.distributed.tp_size == 1
        ), "Probing not supported with tensor parallelism"
        assert (
            args.distributed.selective_activation_checkpointing is False
        ), "Probing not supported with selective activation checkpointing"
# Module-level preemption flag, kept in a dict so it can be flipped in place.
# NOTE(review): train() installs apps.main.train.set_preemption_flag as the
# signal handler and never reads this dict, so it looks unused here — confirm.
preemption_flag = {"flag": False}
def train(args: TrainAnswerOnlyArgs):
    """Run answer-only (prompt-masked) language-model training.

    Setup:
      * builds the tokenizer and validates ``args``;
      * initialises torch.distributed and the device mesh;
      * builds and parallelises the model, created on the meta device and
        materialised on CUDA;
      * either loads ``checkpoint.init_ckpt_path`` or initialises weights;
      * resumes from any checkpoint the CheckpointManager finds.

    Training runs until ``train_state.step`` reaches ``args.steps``. Each
    optimizer step accumulates ``args.grad_acc_steps`` micro-batches. The loss
    is the mean cross-entropy over tokens whose label is not -100, i.e. answer
    tokens only.

    Every ``args.logging.freq`` steps it logs loss, token accuracy,
    bits-per-byte, throughput and (optionally) per-source statistics.
    Checkpoints are written on the ``checkpoint.dump.every`` /
    ``checkpoint.eval.every`` schedules, and once more when the loop ends
    unless its last iteration already saved.

    Raises:
        AssertionError: on missing dump_dir / data.root_dir / sources or an
            inconsistent configuration.
        RuntimeError: if no CUDA device is visible.
    """
    with ExitStack() as context_stack:
        assert args.dump_dir, "dump_dir is required"
        assert args.data.root_dir is not None, "data.root_dir is required"
        assert len(args.data.sources) > 0, "data.sources must be non-empty"
        tokenizer = build_tokenizer(
            args.data.tokenizer.name,
            args.data.tokenizer.path,
            args.data.tokenizer.tokenizers,
            args.data.tokenizer.dropout,
            superset_code_name=args.data.tokenizer.superset_code_name,
            n_words=args.data.tokenizer.n_words,
        )
        validate_train_args(
            args,
            tokenizer.n_words,
        )
        if get_is_master():
            os.makedirs(args.dump_dir, exist_ok=True)
            dump_config(args, Path(args.dump_dir) / "config.yaml")
        init_logger(str(Path(args.dump_dir) / "train.log"))
        # For handling preemption signals.
        # NOTE(review): no preemption flag is checked in the loop below, so a
        # preemption signal does not trigger a save/requeue here — confirm.
        init_signal_handler(set_preemption_flag)

        if not torch.cuda.is_available() or torch.cuda.device_count() == 0:
            raise RuntimeError(
                "No CUDA GPUs are visible before distributed init. "
                f"cuda_available={torch.cuda.is_available()} "
                f"device_count={torch.cuda.device_count()} "
                f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')} "
                f"SLURM_JOB_ID={os.environ.get('SLURM_JOB_ID')} "
                f"SLURM_NODELIST={os.environ.get('SLURM_JOB_NODELIST')} "
                f"SLURM_NTASKS={os.environ.get('SLURM_NTASKS')}. "
                "Ensure the job is launched on a GPU allocation (e.g. launcher=sbatch with ngpu>=1), "
                "and verify inside the job with `nvidia-smi`."
            )
        setup_env(args.env)
        setup_torch_distributed(args.distributed)
        world_mesh = get_device_mesh(args.distributed)
        logger.info(f"Starting job: {args.name}")

        # Data-parallel coordinates (replicate x shard). dp_degree feeds the
        # total_tokens metric; dp_rank is only logged.
        # NOTE(review): _batch_iterator shards data by *global* rank/world size,
        # not by these DP coordinates, so tensor-parallel peers would read
        # different rows — confirm before enabling TP.
        dp_mesh = world_mesh["dp_replicate"]
        dp_degree = dp_mesh.size()
        dp_rank = dp_mesh.get_local_rank()
        source_names = list(args.data.sources.keys())
        if args.distributed.dp_shard > 1:
            dp_rank = dp_rank * world_mesh["dp_shard"].size() + world_mesh["dp_shard"].get_local_rank()
            dp_degree *= world_mesh["dp_shard"].size()

        logger.info(f"Running on dp rank : {dp_rank}")
        logger.info(f"Running on dp size : {dp_degree}")

        torch.manual_seed(args.seed)
        logger.info("Building model")
        if args.model.vocab_size < 0:
            args.model.vocab_size = tokenizer.n_words
        assert args.model.vocab_size == tokenizer.n_words, "model.vocab_size must match tokenizer.n_words"
        args.model.max_seqlen = args.data.seq_len

        # Build on the meta device so no memory is allocated before sharding.
        with torch.device("meta"):
            model = LMTransformer(args.model)
        logger.info("Model is built !")

        model_param_count = get_num_params(model)

        model = parallelize_model(
            model,
            world_mesh,
            args.model,
            args.distributed,
            fsdp_grouping_plan=build_fsdp_grouping_plan(args.model),
            tp_parallelize=tp_parallelize,
            no_recompute_ops=get_no_recompute_ops(),
        )

        # Once the model is sharded, allocate (uninitialised) tensors on the GPU.
        # Every parameter and buffer must then be initialised or loaded below;
        # anything left untouched holds garbage and fails silently (e.g. NaN grads).
        model = model.to_empty(device="cuda")
        logger.info(f"Model size: {model_param_count:,} total parameters")

        optimizer, scheduler = build_optimizer(model, args.optim, args.steps)
        use_bf16_autocast = str(args.distributed.model_dtype).lower() in {"bf16", "bfloat16"}

        if args.checkpoint.init_ckpt_path:
            # todo: maybe auto load the largest ckpt
            logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}")
            if args.checkpoint.load_init_optimizer_state:
                load_from_checkpoint(args.checkpoint.init_ckpt_path, model, optimizer, model_key="model")  # Put model_key="" if its directly the model checkpoint
            else:
                load_from_checkpoint(args.checkpoint.init_ckpt_path, model, model_key="model")  # Put model_key="" if its directly the model checkpoint
            model.rope_embeddings.reset_parameters()  # For RoPe initialization since it's a buffer it might not be loaded
        else:
            with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
                torch.manual_seed(args.model.seed)
                model.init_weights()
        check_model_value_range(model, range=10.0, std=1.0)

        data_iter = _batch_iterator(args, tokenizer)

        # Lookup table: token id -> byte count from build_token_bytes (0 for ids
        # it does not list), used for bits-per-byte. It is filled on the CPU
        # with one scatter and moved to the GPU once, instead of one small
        # CUDA write per vocabulary entry.
        token_bytes_dict = build_token_bytes(tokenizer, tokenizer.n_words)
        token_bytes_tensor = torch.zeros(tokenizer.n_words, dtype=torch.int64)
        if token_bytes_dict:
            tb_ids = torch.tensor(list(token_bytes_dict.keys()), dtype=torch.long)
            tb_vals = torch.tensor(list(token_bytes_dict.values()), dtype=torch.int64)
            token_bytes_tensor[tb_ids] = tb_vals
        token_bytes_tensor = token_bytes_tensor.to("cuda")

        # Placeholder loader state for TrainState/checkpointing.
        # NOTE(review): the iterator's real read position is not stored here, so
        # resuming presumably restarts each source from its initial shard state — confirm.
        data_loader_state = {
            "start_token": 0,
            "it_state": {},
            "output_seq_len": args.data.seq_len,
            "n_views": args.data.n_views,
            "seq_len": 0,
        }

        train_state = TrainState(
            step=0,
            acc_step=0,
            data_loader_state=data_loader_state,
            scheduler=scheduler,
        )

        checkpoint = CheckpointManager.instantiate_and_make_dir(args.checkpoint)
        checkpoint.load(model, optimizer, train_state, world_mesh)
        if args.checkpoint.save_init_ckpt:
            if checkpoint.save(
                model,
                optimizer,
                train_state,
                args,
                device_mesh=world_mesh,
            ):
                _ = consolidate_checkpoints(str(checkpoint.existing_saves[-1]))

        # Automatic GC is disabled; it is run manually every gc_collect_freq steps
        # so all ranks pause at the same time.
        gc.disable()

        model.train()
        metric_logger = context_stack.enter_context(
            MetricLogger(Path(args.dump_dir) / "metrics.jsonl", args)
        )
        nwords_since_last_log = 0
        time_last_log = timer()
        gc.collect()
        saved = False
        while train_state.step < args.steps:
            # We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1
            train_state.acc_step += 1
            train_state.acc_step = train_state.acc_step % args.grad_acc_steps

            # get batch
            curr_lr = float(optimizer.param_groups[0]["lr"])
            data_load_start = timer()
            # _batch_iterator yields exactly (input_ids, labels, source_ids).
            input_ids, labels, source_ids = next(data_iter)
            maybe_dump_training_batch(
                tokenizer,
                input_ids,
                labels,
                source_ids,
                source_names,
                train_state.step,
                train_state.acc_step,
                args.dump_dir,
            )
            data_load_time = round(timer() - data_load_start, 4)
            # Tokens fed to the model since the last log line (padding included);
            # drives the wps / FLOPS metrics.
            nwords_since_last_log += input_ids.numel()

            if every_n_steps(train_state, args.gc_collect_freq, acc_step=0):
                logger.info("garbage collection")
                # we do garbage collection manually otherwise different processes
                # run the GC at different times so they slow down the whole pipeline
                gc.collect()

            input_ids = input_ids.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            source_ids = source_ids.cuda(non_blocking=True)

            # forward
            start_timer = torch.cuda.Event(enable_timing=True)
            end_timer = torch.cuda.Event(enable_timing=True)
            start_timer.record()

            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=use_bf16_autocast):
                logits = model(input_ids)
            token_losses = F.cross_entropy(
                logits.flatten(end_dim=-2).float(),
                labels.flatten(end_dim=-1),
                ignore_index=-100,
                reduction="none",
            ).view_as(labels)
            valid_token_mask = labels != -100
            # Mean over supervised tokens, written as masked sum / clamped count.
            # This keeps the loss attached to the autograd graph even when every
            # label in the batch is -100 (a fresh zero tensor has no grad_fn, so
            # backward() would raise). It also avoids a per-step host sync.
            loss = (token_losses * valid_token_mask).sum() / valid_token_mask.sum().clamp_min(1)

            # We scale loss with grad_acc_steps so the gradient is the same
            # regardless of grad_acc_steps
            loss = loss / args.grad_acc_steps
            # backward on scaled loss to create scaled gradients
            loss.backward()
            # For logging we undo that scaling
            loss = loss.detach() * args.grad_acc_steps

            # Accuracy / BPB / per-source statistics (logging only, not used for training)
            with torch.no_grad():
                preds = logits.argmax(dim=-1)
                valid = valid_token_mask
                denom = valid.sum().clamp_min(1)
                corrects = ((preds == labels) & valid).float().sum()
                token_count = denom.float()

                # BPB: sum nats and bytes only over valid (non-ignored) tokens
                y1d = labels.reshape(-1)
                valid1d = valid.reshape(-1)
                ysafe = torch.where(valid1d, y1d, torch.zeros_like(y1d))
                nb = torch.where(valid1d, token_bytes_tensor[ysafe], torch.zeros_like(y1d))
                counted = nb > 0
                bpb_nats_sum = token_losses.reshape(-1)[counted].sum()
                bpb_bytes_sum = nb[counted].sum()

                source_stats = None
                if args.track_source_metrics:
                    source_stats = {}
                    for source_id, source_name in enumerate(source_names):
                        sample_mask = source_ids == source_id
                        source_token_mask = valid & sample_mask.unsqueeze(1)
                        source_token_count = source_token_mask.sum()
                        if source_token_count.item() > 0:
                            source_loss_sum = token_losses[source_token_mask].sum()
                            source_corrects = ((preds == labels) & source_token_mask).float().sum()
                        else:
                            source_loss_sum = token_losses.new_zeros(())
                            source_corrects = token_losses.new_zeros(())
                        source_stats[source_name] = {
                            "loss_sum": source_loss_sum,
                            "token_count": source_token_count.float(),
                            "corrects": source_corrects,
                        }

            # optimizer step (only after the last micro-batch of the accumulation window)
            grad_norm = -1.0
            if train_state.acc_step == 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=args.optim.clip, foreach=True
                )

                grad_norm = (
                    grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm
                ).item()

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                train_state.step += 1

            # training iteration complete
            end_timer.record()

            torch.cuda.synchronize()

            curr_iter_time = round(start_timer.elapsed_time(end_timer) * 1e-3, 4)

            # log metrics
            if every_n_steps(
                train_state,
                args.logging.freq,
                acc_step=None if args.logging.acc_freq else 0,
                acc_freq=args.logging.acc_freq,
            ):
                time_delta = timer() - time_last_log
                wps = nwords_since_last_log / (time_delta * args.distributed.tp_size)

                total_acc_steps = (
                    args.grad_acc_steps * train_state.step + train_state.acc_step
                )
                tokens_per_gpu = (
                    total_acc_steps * args.data.batch_size * args.data.seq_len
                )
                total_tokens = dp_degree * tokens_per_gpu
                # This is an estimate and the correct values may change
                # if you change the architecture
                # Use xformer's analyze profile trace to get actual measurement
                FLOPS = (
                    get_num_flop_per_token(
                        model_param_count - args.model.vocab_size * args.model.dim,
                        args.model.n_layers,
                        args.model.dim,
                        args.data.seq_len,
                    )
                    * wps
                )
                metrics = flatten_dict(
                    {
                        "global_step": train_state.step,
                        "acc_step": train_state.acc_step,
                        "speed": {
                            "wps": wps,
                            "FLOPS": FLOPS,
                            "curr_iter_time": curr_iter_time,
                            "data_load_time": data_load_time,
                        },
                        "optim": {
                            "grad_norm": grad_norm,
                            "lr": curr_lr,
                            "total_tokens": total_tokens,
                        },
                    },
                    sep="/",
                )

                to_sync = {}
                to_sync["loss/out"] = loss.item()
                to_sync["corrects/out"] = corrects.item()
                to_sync["token_count/out"] = token_count.item()
                if args.track_source_metrics and source_stats is not None:
                    for source_name in source_names:
                        stats = source_stats[source_name]
                        to_sync[f"sources/{source_name}/loss_sum"] = stats["loss_sum"].item()
                        to_sync[f"sources/{source_name}/token_count"] = stats["token_count"].item()
                        to_sync[f"sources/{source_name}/corrects"] = stats["corrects"].item()

                synced_metrics = dist_mean_dict(to_sync)
                # Ratios of rank-means equal ratios of global sums, so accuracy is exact.
                synced_token_count = max(float(synced_metrics["token_count/out"]), 1e-8)
                synced_metrics["accuracy/out"] = float(synced_metrics["corrects/out"]) / synced_token_count

                # BPB: all_reduce sum nats and bytes across ranks, then divide
                _bpb_nats = bpb_nats_sum.clone()
                _bpb_bytes = bpb_bytes_sum.float().clone()
                if dist.is_initialized() and dist.get_world_size() > 1:
                    dist.all_reduce(_bpb_nats, op=dist.ReduceOp.SUM)
                    dist.all_reduce(_bpb_bytes, op=dist.ReduceOp.SUM)
                total_bytes = float(_bpb_bytes.item())
                synced_metrics["bpb"] = float(_bpb_nats.item()) / (math.log(2) * total_bytes) if total_bytes > 0 else float("nan")

                if args.track_source_metrics:
                    for source_name in source_names:
                        source_tokens = max(float(synced_metrics[f"sources/{source_name}/token_count"]), 1e-8)
                        synced_metrics[f"sources/{source_name}/loss"] = (
                            float(synced_metrics[f"sources/{source_name}/loss_sum"]) / source_tokens
                        )
                        synced_metrics.pop(f"sources/{source_name}/loss_sum")
                        synced_metrics[f"sources/{source_name}/accuracy"] = (
                            float(synced_metrics[f"sources/{source_name}/corrects"]) / source_tokens
                        )
                        synced_metrics.pop(f"sources/{source_name}/corrects")

                metrics.update(synced_metrics)
                if get_is_master():
                    metric_logger.log(metrics)

                nwords_since_last_log = 0
                time_last_log = timer()
                logger.info(
                    f"step: {train_state.step}"
                    f" acc: {train_state.acc_step}"
                    f" loss: {round(loss.item(),4):>7}"
                    f" bpb: {metrics['bpb']:.4f}"
                    f" accuracy: {metrics['accuracy/out']:>7}"
                    f" grad: {grad_norm:.2e}"
                    f" flops: {FLOPS:.2e}"
                    f" wps: {wps:.2e}"
                    f" iter: {curr_iter_time:>7}"
                    f" data: {data_load_time:>5}"
                    f" lr: {curr_lr:.2e}"
                )

            # `saved` records whether *this* iteration wrote a checkpoint, so the
            # final save below is skipped only when the last step is already on disk.
            saved = False
            if every_n_steps(
                train_state, args.checkpoint.dump.every, acc_step=0
            ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
                saved = checkpoint.save(
                    model,
                    optimizer,
                    train_state,
                    args,
                    device_mesh=world_mesh,
                )

        # Make sure the final state is checkpointed even when args.steps is not a
        # multiple of the dump/eval intervals.
        if not saved:
            checkpoint.save(
                model,
                optimizer,
                train_state,
                args,
                device_mesh=world_mesh,
            )
def main():
    """Command-line entry point.

    Expects ``config=<path to YAML>`` among the command-line arguments. The
    effective configuration is built in layers:

    1. the ``TrainAnswerOnlyArgs`` defaults;
    2. the YAML file on top of the defaults;
    3. the remaining ``key=value`` dotlist overrides on top of both.

    The result is converted back to a ``TrainAnswerOnlyArgs`` instance and
    passed to ``train()``.
    """
    overrides = OmegaConf.from_cli()
    yaml_cfg = OmegaConf.load(overrides.config)
    # "config" only locates the YAML file; TrainAnswerOnlyArgs has no such field,
    # so it is dropped before merging into the structured schema.
    del overrides.config

    merged = OmegaConf.merge(OmegaConf.structured(TrainAnswerOnlyArgs()), yaml_cfg, overrides)
    train(OmegaConf.to_object(merged))


if __name__ == "__main__":
    main()

Xet Storage Details

Size:
33.6 kB
·
Xet hash:
748828cc79ca111bac13373b6ef9dbbc41c51d0c98dd0142e9874b2b9733e0e7

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.