import multiprocessing
import time
from multiprocessing.managers import Namespace
import torch
import numpy as np
from omegaconf import DictConfig, open_dict
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import (
    LRScheduler,
    SequentialLR,
    LinearLR,
    CosineAnnealingLR,
)
from osuT5.model.osu_t import OsuT
from osuT5.tokenizer import Tokenizer


def get_shared_training_state() -> Namespace:
    """Create a manager-backed namespace for sharing training progress across processes."""
    mgr = multiprocessing.Manager()
    shared = mgr.Namespace()
    shared.current_train_step = 1
    shared.current_epoch = 1
    shared.last_log = time.time()
    shared.current_loss = np.inf
    shared.best_loss = np.inf
    return shared
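
# A rough usage sketch (not taken from the training loop itself): the manager-backed
# namespace lets a monitoring/logging process observe training progress, since
# attribute writes made in the training process are visible to every process
# holding a reference to `shared`:
#
#   shared = get_shared_training_state()
#   shared.current_train_step += 1
#   shared.current_loss = 0.42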


def get_model(args: DictConfig, tokenizer: Tokenizer) -> OsuT:
    model = OsuT(args, tokenizer)
    return model


def get_tokenizer(args: DictConfig) -> Tokenizer:
    return Tokenizer(args)


def get_optimizer(model: OsuT, args: DictConfig) -> Optimizer:
    # Exclude biases and normalization weights from weight decay.
    no_decay = ["bias", "LayerNorm", "layernorm", "layer_norm", "ln"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": args.optim.weight_decay,
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    if args.optim.name == 'adamw':
        from transformers import AdamW
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=args.optim.base_lr,
        )
    elif args.optim.name == 'adamwscale':
        from .copied_utils import AdamWScale
        optimizer = AdamWScale(
            optimizer_grouped_parameters,
            lr=args.optim.base_lr,
        )
    elif args.optim.name == 'adafactor':
        from transformers import Adafactor
        optimizer = Adafactor(
            optimizer_grouped_parameters,
            lr=args.optim.base_lr,
            relative_step=False,
        )
    else:
        raise NotImplementedError(f"Unknown optimizer: {args.optim.name}")

    return optimizer


def get_scheduler(optimizer: Optimizer, args: DictConfig) -> LRScheduler:
    # Linear warmup from half of the base LR up to the full base LR.
    scheduler_p1 = LinearLR(
        optimizer,
        start_factor=0.5,
        end_factor=1,
        total_iters=args.optim.warmup_steps,
        last_epoch=-1,
    )
    # Cosine decay from the base LR down to final_cosine over the remaining steps.
    scheduler_p2 = CosineAnnealingLR(
        optimizer,
        T_max=args.optim.total_steps - args.optim.warmup_steps,
        eta_min=args.optim.final_cosine,
    )
    # Switch from warmup to cosine decay once warmup_steps have elapsed.
    scheduler = SequentialLR(
        optimizer,
        schedulers=[scheduler_p1, scheduler_p2],
        milestones=[args.optim.warmup_steps],
    )
    return scheduler
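
# A minimal sketch of the `args.optim` config these helpers read; the values
# below are illustrative assumptions, not the project's actual defaults:
#
#   optim:
#     name: adamwscale      # one of: adamw, adamwscale, adafactor
#     base_lr: 2.0e-2       # peak learning rate reached after warmup
#     weight_decay: 0.0     # applied only to non-bias/non-norm parameters
#     warmup_steps: 10000   # length of the linear warmup phase
#     total_steps: 65536    # total optimizer steps (warmup + cosine decay)
#     final_cosine: 1.0e-5  # eta_min of the cosine annealing phase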


def worker_init_fn(worker_id: int) -> None:
    """
    Give each dataloader worker a unique, non-overlapping slice of the full dataset.
    """
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset  # the dataset copy in this worker process
    overall_start = dataset.start
    overall_end = dataset.end
    # Configure the dataset so this worker only processes its share of the workload.
    per_worker = int(
        np.ceil((overall_end - overall_start) / float(worker_info.num_workers)),
    )
    dataset.start = overall_start + worker_id * per_worker
    dataset.end = min(dataset.start + per_worker, overall_end)
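
# A rough usage sketch, assuming an iterable dataset that exposes integer
# `start`/`end` attributes as worker_init_fn expects; the dataset class and
# config keys below are placeholders, not necessarily the project's real names:
#
#   dataset = SomeIterableDataset(args, tokenizer)  # hypothetical; defines .start/.end
#   dataloader = DataLoader(
#       dataset,
#       batch_size=args.optim.batch_size,            # hypothetical config key
#       num_workers=args.dataloader.num_workers,     # hypothetical config key
#       worker_init_fn=worker_init_fn,
#   )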