# osuT5/utils/model_utils.py
import multiprocessing
import time
from multiprocessing.managers import Namespace
import torch
import numpy as np
from omegaconf import DictConfig, open_dict
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import (
LRScheduler,
SequentialLR,
LinearLR,
CosineAnnealingLR,
)
from osuT5.model.osu_t import OsuT
from osuT5.tokenizer import Tokenizer
def get_shared_training_state() -> Namespace:
    """Create a process-shared namespace tracking training progress.

    Returns:
        A ``multiprocessing`` manager Namespace with step/epoch counters
        starting at 1, the last-log timestamp set to now, and both loss
        trackers initialised to infinity so the first observed loss
        always becomes the new best.
    """
    mgr = multiprocessing.Manager()
    shared = mgr.Namespace()
    shared.current_train_step = 1
    shared.current_epoch = 1
    shared.last_log = time.time()
    # np.inf, not np.Infinity: the Infinity alias was removed in NumPy 2.0.
    shared.current_loss = np.inf
    shared.best_loss = np.inf
    return shared
def get_model(args: DictConfig, tokenizer: Tokenizer) -> OsuT:
    """Instantiate the OsuT model from the run configuration and tokenizer."""
    return OsuT(args, tokenizer)
def get_tokenizer(args: DictConfig) -> Tokenizer:
    """Build a Tokenizer from the run configuration."""
    tokenizer = Tokenizer(args)
    return tokenizer
def get_optimizer(model: OsuT, args: DictConfig) -> Optimizer:
    """Build the optimizer selected by ``args.optim.name``.

    Parameters are split into two groups: weight decay
    (``args.optim.weight_decay``) is applied to everything except biases
    and normalization weights — any parameter whose name contains one of
    the ``no_decay`` markers — which get a decay of 0.

    Args:
        model: Model whose parameters will be optimized.
        args: Config; reads ``optim.name``, ``optim.base_lr`` and
            ``optim.weight_decay``.

    Returns:
        The configured optimizer.

    Raises:
        NotImplementedError: If ``args.optim.name`` is not one of
            ``adamw``, ``adamwscale`` or ``adafactor``.
    """
    no_decay = ["bias", "LayerNorm", "layernorm", "layer_norm", "ln"]
    # Single pass over named_parameters (the original iterated twice).
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if any(marker in name for marker in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    optimizer_grouped_parameters = [
        {"params": decay_params, "weight_decay": args.optim.weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
    if args.optim.name == 'adamw':
        # transformers.AdamW is deprecated/removed upstream;
        # torch.optim.AdamW is the documented drop-in replacement.
        from torch.optim import AdamW
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=args.optim.base_lr,
        )
    elif args.optim.name == 'adamwscale':
        from .copied_utils import AdamWScale
        optimizer = AdamWScale(
            optimizer_grouped_parameters,
            lr=args.optim.base_lr,
        )
    elif args.optim.name == 'adafactor':
        from transformers import Adafactor
        optimizer = Adafactor(
            optimizer_grouped_parameters,
            lr=args.optim.base_lr,
            relative_step=False,
        )
    else:
        raise NotImplementedError(f"Unknown optimizer: {args.optim.name}")
    return optimizer
def get_scheduler(optimizer: Optimizer, args: DictConfig) -> LRScheduler:
    """Warmup-then-cosine learning-rate schedule.

    Linearly ramps the LR from half of its base value up to the full
    value over ``args.optim.warmup_steps`` steps, then cosine-anneals it
    down to ``args.optim.final_cosine`` over the remaining
    ``total_steps - warmup_steps`` steps.
    """
    warmup_steps = args.optim.warmup_steps
    warmup = LinearLR(
        optimizer,
        start_factor=0.5,
        end_factor=1,
        total_iters=warmup_steps,
        last_epoch=-1,
    )
    cosine = CosineAnnealingLR(
        optimizer,
        T_max=args.optim.total_steps - warmup_steps,
        eta_min=args.optim.final_cosine,
    )
    return SequentialLR(
        optimizer,
        schedulers=[warmup, cosine],
        milestones=[warmup_steps],
    )
def worker_init_fn(worker_id: int) -> None:
    """
    Give each dataloader a unique slice of the full dataset.
    """
    info = torch.utils.data.get_worker_info()
    ds = info.dataset  # per-worker copy of the dataset
    start, end = ds.start, ds.end
    # Split [start, end) into num_workers contiguous chunks, rounding up
    # so the full range is covered; the last chunk may be shorter.
    chunk = int(np.ceil((end - start) / float(info.num_workers)))
    ds.start = start + worker_id * chunk
    ds.end = min(ds.start + chunk, end)