# osu_mapper2/osuT5/utils/train_utils.py
import glob
import os
import time
from multiprocessing.managers import Namespace
import torch
import wandb
from accelerate import Accelerator
from accelerate.logging import get_logger
from omegaconf import DictConfig
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from osuT5.tokenizer import Tokenizer, EventType
from osuT5.model import OsuT
from .log_utils import Averager
logger = get_logger(__name__)
def forward(model: OsuT, batch):
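    """Run a training forward pass and return the loss together with detached stats for logging."""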
outputs = model(**batch)
loss = outputs.loss
stats = {"loss": loss.detach()}
return loss, stats
def forward_eval(model: OsuT, batch):
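    """Run an evaluation forward pass and return the raw model outputs (loss and logits)."""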
outputs = model(**batch)
return outputs
def add_prefix(prefix: str, stats: dict[str, float]):
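    """Prefix every stat key with the given section name, e.g. 'loss' -> 'train/loss'."""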
return {f"{prefix}/{k}": v for k, v in stats.items()}
def maybe_save_checkpoint(accelerator: Accelerator, args: DictConfig, shared: Namespace):
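    """Save the accelerator state every `args.checkpoint.every_steps` steps (and at the end of
    training), and log the checkpoint directory to wandb as a model artifact."""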
if (
shared.current_train_step > args.optim.total_steps
or shared.current_train_step % args.checkpoint.every_steps == 0
):
if shared.current_loss < shared.best_loss:
shared.best_loss = shared.current_loss
is_best = True
else:
is_best = False
output_dir = f"checkpoint-{shared.current_train_step}"
accelerator.wait_for_everyone()
        # Safe serialization strips T5's shared (tied) tensors, which breaks loading afterwards, so save without it.
accelerator.save_state(output_dir=output_dir, safe_serialization=False)
wandb_tracker = accelerator.get_tracker("wandb")
if wandb_tracker is not None:
art = wandb.Artifact(
f"osuT5-{wandb.run.id}",
type="model",
metadata={
"format": "accelerate",
"src_seq_len": args.data.src_seq_len,
"tgt_seq_len": args.data.tgt_seq_len,
"num_classes": args.data.num_classes,
"num_diff_classes": args.data.num_diff_classes,
"max_difficulty": args.data.max_diff,
"class_dropout_prob": args.data.class_dropout_prob,
"diff_dropout_prob": args.data.diff_dropout_prob,
"spectrogram": args.model.spectrogram,
"current_train_step": shared.current_train_step,
"current_epoch": shared.current_epoch,
"current_loss": shared.current_loss,
},
)
for file in os.listdir(output_dir):
art.add_file(os.path.join(output_dir, file))
wandb.log_artifact(art, aliases=["best"] if is_best else None)
logger.info(f"Logged checkpoint to wandb: {art.name}")
def maybe_eval(
model: OsuT,
accelerator: Accelerator,
dataloader: DataLoader,
tokenizer: Tokenizer,
args: DictConfig,
shared: Namespace,
):
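    """Evaluate the model every `args.eval.every_steps` steps (and once training passes total_steps),
    restoring train mode afterwards."""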
if (
shared.current_train_step > args.optim.total_steps
or shared.current_train_step % args.eval.every_steps == 0
):
model.eval()
with torch.no_grad():
eval_model(model, accelerator, dataloader, tokenizer, args, shared)
shared.last_log = time.time()
model.train()
def maybe_logging(
model: OsuT,
accelerator: Accelerator,
optimizer: Optimizer,
averager: Averager,
args: DictConfig,
shared: Namespace,
):
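    """Log averaged training stats, optional weight norm, learning rate, and seconds per step
    every `args.logging.every_steps` steps."""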
def extra_stats(args, shared, model, optimizer):
stats = {}
if args.logging.weights_l2:
weights_l2 = (
sum(p.detach().norm(2).item() ** 2 for p in model.parameters() if p.requires_grad) ** 0.5
)
stats["weights_l2"] = weights_l2
stats["lr"] = optimizer.param_groups[0]["lr"]
stats["seconds_per_step"] = (
time.time() - shared.last_log
) / args.logging.every_steps
return stats
if shared.current_train_step % args.logging.every_steps == 0:
stats = extra_stats(args, shared, model, optimizer)
averager.update(stats)
averaged_stats = averager.average()
averaged_stats["epoch"] = shared.current_epoch
averaged_stats = add_prefix("train", averaged_stats)
accelerator.log(averaged_stats, step=shared.current_train_step)
averaged_stats["step"] = shared.current_train_step
logger.info(averaged_stats)
shared.last_log = time.time()
def maybe_grad_clip_and_grad_calc(
model: OsuT,
accelerator: Accelerator,
args: DictConfig,
):
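    """Clip gradients to `args.optim.grad_clip` when enabled; return the gradient L2 norm
    if `args.logging.grad_l2` is set, otherwise an empty dict."""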
if args.optim.grad_clip > 0:
grad_l2 = accelerator.clip_grad_norm_(
parameters=model.parameters(),
max_norm=args.optim.grad_clip,
norm_type=2,
).item()
else:
grad_l2 = None
if args.logging.grad_l2:
if grad_l2 is None:
            grad_l2 = (
                sum(
                    p.grad.detach().norm(2).item() ** 2
                    for p in model.parameters()
                    if p.grad is not None
                )
                ** 0.5
            )
return {"grad_l2": grad_l2}
else:
return {}
# noinspection PyUnresolvedReferences,PyTypeChecker
def eval_model(
model: OsuT,
accelerator: Accelerator,
dataloader: DataLoader,
tokenizer: Tokenizer,
args: DictConfig,
shared: Namespace,
):
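    """Evaluate for up to `args.eval.steps` optimizer steps, logging loss and the timing, spacing,
    and other-token accuracies under the 'test' prefix."""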
shared.last_log = time.time()
averager = Averager()
for batch_id, batch in enumerate(dataloader, start=1):
if batch_id == args.eval.steps * args.optim.grad_acc:
break
        # We can't use the beatmap idx of the test set because it is not known to the model
del batch["beatmap_idx"]
outputs = forward_eval(model, batch)
# Reduce loss over all processes
loss = outputs.loss
loss = accelerator.reduce(loss, reduction="mean")
        # Gather labels and predictions over all processes and drop duplicates
        preds = torch.argmax(outputs.logits, dim=-1)
        labels = batch["labels"]
        preds, labels = accelerator.gather_for_metrics((preds, labels))
# Calculate accuracy metrics
stats = {"loss": loss.detach(),
"timing_acc": acc_range(preds, labels, tokenizer.event_start[EventType.TIME_SHIFT],
tokenizer.event_end[EventType.TIME_SHIFT]),
"spacing_acc": acc_range(preds, labels, tokenizer.event_start[EventType.DISTANCE],
tokenizer.event_end[EventType.DISTANCE]),
"other_acc": acc_range(preds, labels, tokenizer.event_end[EventType.DISTANCE],
tokenizer.event_end[EventType.DISTANCE] + tokenizer.vocab_size_out)}
averager.update(stats)
averager.update({"time": time.time() - shared.last_log})
averaged_stats = averager.average()
averaged_stats = add_prefix("test", averaged_stats)
accelerator.log(averaged_stats, step=shared.current_train_step)
logger.info(averaged_stats)
shared.current_loss = averaged_stats["test/loss"]
def acc_range(preds, labels, start_index, end_index):
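    """Per-token accuracy (as a 0/1 array) over positions whose label lies in [start_index, end_index)."""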
index = (start_index <= labels) & (labels < end_index)
range_labels = labels[index]
range_preds = preds[index]
return (range_preds == range_labels).detach().float().cpu().numpy()
def train(
model: OsuT,
train_dataloader: DataLoader,
test_dataloader: DataLoader,
accelerator: Accelerator,
lr_scheduler: LRScheduler,
optimizer: Optimizer,
tokenizer: Tokenizer,
args: DictConfig,
shared: Namespace,
profiler=None,
):
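    """Main training loop with gradient accumulation: forward/backward, optional gradient clipping,
    optimizer and scheduler steps, plus periodic logging, evaluation, and checkpointing
    until `args.optim.total_steps` is reached."""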
model.train()
train_averager = Averager()
while shared.current_train_step <= args.optim.total_steps:
        # In case there are leftover accumulated gradients from the previous epoch, reset the optimizer
optimizer.zero_grad(set_to_none=True)
accelerator.print(f"Epoch {shared.current_epoch}")
for batch_id, batch in enumerate(train_dataloader, start=1):
with accelerator.accumulate(model):
if shared.current_train_step > args.optim.total_steps:
break
loss, stats = forward(model, batch)
accelerator.backward(loss)
train_averager.update(stats)
if accelerator.sync_gradients:
stats = maybe_grad_clip_and_grad_calc(model, accelerator, args)
train_averager.update(stats)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
if profiler is not None:
profiler.step()
if accelerator.sync_gradients:
maybe_logging(model, accelerator, optimizer, train_averager, args, shared)
maybe_eval(model, accelerator, test_dataloader, tokenizer, args, shared)
maybe_save_checkpoint(accelerator, args, shared)
shared.current_train_step += 1
shared.current_epoch += 1
if not (args.profile.do_profile and args.profile.early_stop):
maybe_eval(model, accelerator, test_dataloader, tokenizer, args, shared)
maybe_save_checkpoint(accelerator, args, shared)
accelerator.end_training()
def train_profiling(
model: OsuT,
train_dataloader: DataLoader,
test_dataloader: DataLoader,
accelerator: Accelerator,
lr_scheduler: LRScheduler,
optimizer: Optimizer,
tokenizer: Tokenizer,
args: DictConfig,
shared: Namespace,
):
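    """Run `train` under the torch profiler using the schedule from `args.profile`, writing
    TensorBoard traces and uploading them to wandb when a tracker is active."""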
tensorboard_trace_handler = torch.profiler.tensorboard_trace_handler(
"./profiler_logs", worker_name=f"worker_{accelerator.process_index}")
if args.profile.early_stop:
stop_step = (args.profile.wait + args.profile.warmup + args.profile.active) * args.profile.repeat / args.optim.grad_acc
args.optim.total_steps = shared.current_train_step + stop_step
def on_trace_ready(trace):
tensorboard_trace_handler(trace)
wandb_tracker = accelerator.get_tracker("wandb")
if wandb_tracker is not None:
            # Upload the most recent trace produced by the tensorboard handler
            trace_files = glob.glob("./profiler_logs/*.pt.trace.json")
            wandb.save(max(trace_files, key=os.path.getmtime), base_path="profiler_logs")
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(
wait=args.profile.wait,
warmup=args.profile.warmup,
active=args.profile.active,
repeat=args.profile.repeat,
),
on_trace_ready=on_trace_ready,
record_shapes=True,
profile_memory=True,
with_stack=True,
) as p:
train(
model,
train_dataloader,
test_dataloader,
accelerator,
lr_scheduler,
optimizer,
tokenizer,
args,
shared,
p
)