import multiprocessing
import time
from multiprocessing.managers import Namespace

import torch
import numpy as np
from omegaconf import DictConfig, open_dict
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import (
    LRScheduler,
    SequentialLR,
    LinearLR,
    CosineAnnealingLR,
)

from osuT5.model.osu_t import OsuT
from osuT5.tokenizer import Tokenizer


def get_shared_training_state() -> Namespace:
    """Create a manager-backed Namespace for sharing training state across processes."""
    mgr = multiprocessing.Manager()
    shared = mgr.Namespace()
    shared.current_train_step = 1
    shared.current_epoch = 1
    shared.last_log = time.time()
    shared.current_loss = np.inf  # np.Infinity was removed in NumPy 2.0; use np.inf
    shared.best_loss = np.inf
    return shared
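
# A minimal usage sketch (the reader/writer split here is illustrative, not
# something this module enforces):
#
#   shared = get_shared_training_state()
#   shared.current_train_step += 1   # e.g. incremented by the training loop
#   best = shared.best_loss          # e.g. read by a logging process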


def get_model(args: DictConfig, tokenizer: Tokenizer) -> OsuT:
    model = OsuT(args, tokenizer)
    return model


def get_tokenizer(args: DictConfig) -> Tokenizer:
    return Tokenizer(args)
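
# Typical call order for these two thin factories (sketch):
#
#   tokenizer = get_tokenizer(args)
#   model = get_model(args, tokenizer)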


def get_optimizer(model: OsuT, args: DictConfig) -> Optimizer:
    no_decay = ["bias", "LayerNorm", "layernorm", "layer_norm", "ln"]

    optimizer_grouped_parameters = [
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": args.optim.weight_decay,
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    if args.optim.name == 'adamw':
        # transformers.AdamW was deprecated and later removed; torch's
        # AdamW is a drop-in replacement for this usage.
        from torch.optim import AdamW
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=args.optim.base_lr,
        )
    elif args.optim.name == 'adamwscale':
        from .copied_utils import AdamWScale
        optimizer = AdamWScale(
            optimizer_grouped_parameters,
            lr=args.optim.base_lr,
        )
    elif args.optim.name == 'adafactor':
        from transformers import Adafactor
        optimizer = Adafactor(
            optimizer_grouped_parameters,
            lr=args.optim.base_lr,
            relative_step=False,
        )
    else:
        raise NotImplementedError(f"Unknown optimizer: {args.optim.name}")

    return optimizer
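
# For illustration, with hypothetical parameter names: the substring test
# above routes "encoder.block.0.layer_norm.weight" to the zero-weight-decay
# group, while "encoder.block.0.attn.q.weight" keeps args.optim.weight_decay.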


def get_scheduler(optimizer: Optimizer, args: DictConfig) -> LRScheduler:
    # Linear warmup from half the base LR up to the full base LR...
    scheduler_p1 = LinearLR(
        optimizer,
        start_factor=0.5,
        end_factor=1,
        total_iters=args.optim.warmup_steps,
        last_epoch=-1,
    )

    # ...then cosine annealing down to final_cosine for the remaining steps.
    scheduler_p2 = CosineAnnealingLR(
        optimizer,
        T_max=args.optim.total_steps - args.optim.warmup_steps,
        eta_min=args.optim.final_cosine,
    )

    scheduler = SequentialLR(
        optimizer,
        schedulers=[scheduler_p1, scheduler_p2],
        milestones=[args.optim.warmup_steps],
    )

    return scheduler
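
# A minimal wiring sketch (the config values shown are hypothetical, not
# defaults from this module):
#
#   optimizer = get_optimizer(model, args)      # e.g. args.optim.name = "adamwscale"
#   scheduler = get_scheduler(optimizer, args)  # e.g. warmup_steps=10000, total_steps=65536
#   for batch in train_loader:
#       loss = ...                              # forward pass + loss
#       loss.backward()
#       optimizer.step()
#       scheduler.step()                        # SequentialLR steps once per optimizer step
#       optimizer.zero_grad()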


def worker_init_fn(worker_id: int) -> None:
    """
    Give each DataLoader worker a unique slice of the full dataset.
    """
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset
    overall_start = dataset.start
    overall_end = dataset.end

    # Split [start, end) into num_workers contiguous chunks; the last chunk
    # may be shorter when the range does not divide evenly.
    per_worker = int(
        np.ceil((overall_end - overall_start) / float(worker_info.num_workers)),
    )
    dataset.start = overall_start + worker_id * per_worker
    dataset.end = min(dataset.start + per_worker, overall_end)
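
# A minimal sketch of attaching this hook (assumes an IterableDataset subclass
# exposing integer `start`/`end` attributes, as read above; batch size and
# worker count are illustrative):
#
#   loader = DataLoader(
#       dataset,
#       batch_size=128,
#       num_workers=4,
#       worker_init_fn=worker_init_fn,
#   )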