|
|
|
|
|
|
|
|
|
|
|
import gc |
|
|
import os |
|
|
import os.path as osp |
|
|
import subprocess |
|
|
import time |
|
|
import re |
|
|
from typing import List, Optional, Tuple |
|
|
|
|
|
import torch |
|
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
|
|
|
|
|
import glob |
|
|
import shutil |
|
|
from infinity.utils import arg_util |
|
|
import infinity.utils.dist as dist |
|
|
|
|
|
|
|
|
def glob_with_epoch_iter(pattern, recursive=False):
    """Glob ``pattern`` and order the hits from newest to oldest checkpoint.

    "Newest" is decided by the ``ep<E>-iter<I>`` tag found in each file's
    basename, compared as the integer pair ``(E, I)``. Files whose basename
    carries no such tag are ranked as ``(0, 0)`` and therefore come last.

    Args:
        pattern: glob pattern forwarded to ``glob.glob``.
        recursive: forwarded to ``glob.glob``.

    Returns:
        List of matching paths, sorted by ``(epoch, iter)`` in descending order.
    """
    def _epoch_iter_key(path):
        found = re.search(r'ep(\d+)-iter(\d+)', os.path.basename(path))
        if found is None:
            return 0, 0
        return int(found.group(1)), int(found.group(2))

    hits = glob.glob(pattern, recursive=recursive)
    hits.sort(key=_epoch_iter_key, reverse=True)
    return hits
|
|
|
|
|
|
|
|
def glob_with_global_step(pattern, recursive=False):
    """Glob ``pattern`` and order the hits by global step, highest first.

    The step is the integer following ``global_step_`` in each file's
    basename. Files without that tag are ranked as step ``0`` and therefore
    come last.

    Args:
        pattern: glob pattern forwarded to ``glob.glob``.
        recursive: forwarded to ``glob.glob``.

    Returns:
        List of matching paths, sorted by global step in descending order.
    """
    def _global_step_key(path):
        found = re.search(r'global_step_(\d+)', os.path.basename(path))
        return int(found.group(1)) if found else 0

    hits = glob.glob(pattern, recursive=recursive)
    hits.sort(key=_global_step_key, reverse=True)
    return hits
|
|
|
|
|
|
|
|
class CKPTSaver(object):
    """Writes whole-trainer checkpoints as ``global_step_<g_it>.pth`` under ``args.bed``."""

    def __init__(self, is_master: bool, eval_milestone: List[Tuple[float, float]]):
        """Set up the saver.

        Args:
            is_master: whether this process is the one that writes files.
            eval_milestone: initial milestone list kept on the saver and
                stored with each checkpoint.
        """
        self.is_master = is_master
        # Two timestamps [now - 1e5, now], kept as a tensor on the dist device.
        self.time_stamp = torch.tensor([time.time() - 1e5, time.time()], device=dist.get_device())
        # Subprocess handles; none of them is started by this class.
        self.sp_also: Optional[subprocess.Popen] = None
        self.sp_best: Optional[subprocess.Popen] = None
        self.sp_backup: Optional[subprocess.Popen] = None
        self.acc_str, self.eval_milestone = '[no acc str]', eval_milestone

    def sav(
        self, args: arg_util.Args, g_it: int, next_ep: int, next_it: int, trainer,
        acc_str: Optional[str] = None, eval_milestone: Optional[List[Tuple[float, float]]] = None,
        also_save_to: str = None, best_save_to: str = None,
    ):
        """Save a checkpoint for global iteration ``g_it``.

        Every rank calls ``trainer.state_dict()`` and joins the final
        barrier; only the master rank writes the file.

        Args:
            args: run arguments; ``args.bed`` is the output directory,
                ``args.model`` is stored as ``'arch'`` and
                ``args.state_dict()`` as ``'args'``.
            g_it: global iteration, used in the file name and stored as ``'g_it'``.
            next_ep: stored as ``'epoch'``.
            next_it: stored as ``'iter'``.
            trainer: object exposing ``state_dict()``; stored as ``'trainer'``.
            acc_str: if given, replaces the saver's accuracy string before saving.
            eval_milestone: if given, replaces the saver's milestone list before saving.
            also_save_to: accepted for interface compatibility; unused here.
            best_save_to: accepted for interface compatibility; unused here.
        """
        # Honor the optional overrides (same convention as omnistoreCheckpoint.sav);
        # previously they were accepted but silently dropped.
        if acc_str is not None:
            self.acc_str = acc_str
        if eval_milestone is not None:
            self.eval_milestone = eval_milestone

        local_out_ckpt = os.path.join(args.bed, f'global_step_{g_it}.pth')
        # Gathered on every rank, before the master-only branch below.
        trainer_state = trainer.state_dict()
        stt = time.time()
        if self.is_master:
            torch.save({
                'args': args.state_dict(),
                'arch': args.model,
                'epoch': next_ep,
                'iter': next_it,
                'trainer': trainer_state,
                'acc_str': self.acc_str,
                # auto_resume() reads this key via ckpt.get('milestones', []).
                'milestones': self.eval_milestone,
                'g_it': g_it,
            }, local_out_ckpt)
            cost = time.time() - stt
            print(f'Checkpoint save cost: {cost:.2f}s', flush=True)
            print(f'Checkpoint save to: {local_out_ckpt}', flush=True)

        # Drop the state copy and release cached memory before resuming training.
        del trainer_state
        gc.collect()
        torch.cuda.empty_cache()
        dist.barrier()
|
|
|
|
|
|
|
|
def auto_resume(args: arg_util.Args, pattern='*.pth') -> Tuple[List[str], int, int, str, List[Tuple[float, float]], dict, dict]:
    """Locate and load the checkpoint with the highest global step in ``args.bed``.

    Args:
        args: run arguments; ``args.auto_resume`` enables the search and
            ``args.bed`` is the directory that is searched.
        pattern: glob pattern (relative to ``args.bed``) for candidate files.

    Returns:
        A 7-tuple ``(info, ep, g_it, acc_str, eval_milestone, trainer_state,
        args_state)``. ``info`` is a list of log lines. When resuming is
        disabled or no file matches, the remaining fields are
        ``0, 0, '[no acc str]', [], {}, {}``. Note that the third field is the
        checkpoint's ``'g_it'``; its ``'iter'`` entry only appears in the log line.
    """
    info = []
    resume = ''
    if not args.auto_resume:
        info.append('[auto_resume] disabled')
        info.append('[auto_resume quit]')
    else:
        candidates = glob_with_global_step(os.path.join(args.bed, pattern))
        if candidates:
            # Candidates are sorted newest-first, so the head is the one to load.
            resume = candidates[0]
            info.append(f'[auto_resume] auto load from @ {resume} ...')
        else:
            info.append(f'[auto_resume] no ckpt found @ {pattern}')
            info.append('[auto_resume quit]')

    if not resume:
        return info, 0, 0, '[no acc str]', [], {}, {}

    print(f'auto resume from {resume}')
    ckpt = torch.load(resume, map_location='cpu')

    # Wait until every rank has finished reading the file.
    dist.barrier()
    ep = ckpt['epoch']
    it = ckpt['iter']
    g_it = ckpt['g_it']
    eval_milestone = ckpt.get('milestones', [])
    info.append(f'[auto_resume success] resume from ep{ep}, it{it}, eval_milestone: {eval_milestone}')
    acc_str = ckpt.get('acc_str', '[no acc str]')
    return info, ep, g_it, acc_str, eval_milestone, ckpt['trainer'], ckpt['args']
|
|
|
|
|
def omnistore_auto_resume(args: arg_util.Args, pattern='ckpt*.pth'):
    """Pick the checkpoint path to resume from, without loading it.

    ``args.local_out_path`` is searched first and ``args.bed`` second; the
    first directory that yields any match wins, and the first entry of its
    sorted match list is chosen.

    Args:
        args: run arguments; ``args.auto_resume`` enables the search.
        pattern: glob pattern (relative to each directory) for candidate files.

    Returns:
        ``(resume, info)`` where ``resume`` is the chosen path (``''`` when
        resuming is disabled or nothing matched) and ``info`` is a list of
        log lines.
    """
    info = []
    resume = ''
    if not args.auto_resume:
        info.append('[auto_resume] disabled')
        info.append('[auto_resume quit]')
        return resume, info

    found = []
    for folder in (args.local_out_path, args.bed):
        found = glob_with_global_step(os.path.join(folder, pattern))
        if found:
            break

    if found:
        resume = found[0]
        info.append(f'[auto_resume] auto load from @ {resume} ...')
    else:
        info.append(f'[auto_resume] no ckpt found @ {pattern}')
        info.append('[auto_resume quit]')

    return resume, info
|
|
|
|
|
|
|
|
class omnistoreCheckpoint(object):
    """Saves and loads FSDP model checkpoints through omnistore's ``FSDPCheckpointer``.

    Training progress (args, epoch/iteration counters, accuracy string,
    milestones) is kept in a separate ``meta.pth`` file written with
    ``torch.save``.
    """

    def __init__(self, eval_milestone: List[Tuple[float, float]]):
        """Set up the checkpointer.

        Args:
            eval_milestone: initial milestone list, stored in ``meta.pth`` on save.
        """
        # Two timestamps [now - 1e5, now]; slot 0 is refreshed on every master-side save.
        self.time_stamp = torch.tensor([time.time() - 1e5, time.time()], device=dist.get_device())
        self.sp_also: Optional[subprocess.Popen] = None
        self.sp_best: Optional[subprocess.Popen] = None
        # Handle of the most recently launched background ``cp`` (see sav()).
        self.sp_backup: Optional[subprocess.Popen] = None
        self.acc_str, self.eval_milestone = '[no acc str]', eval_milestone

    def sav(
        self, args: arg_util.Args, global_it: int, next_ep: int, next_it: int, fsdp_object: FSDP, optimizer_object: torch.optim.Optimizer,
        acc_str: Optional[str] = None, eval_milestone: Optional[List[Tuple[float, float]]] = None,
    ):
        """Save the model through omnistore and write ``meta.pth`` on the master rank.

        All ranks call ``FSDPCheckpointer.save`` and join the final barrier.
        The master rank additionally writes ``<args.bed>/meta.pth`` and starts
        background ``cp -r`` processes that copy every
        ``<args.local_out_path>/*.txt`` file into ``args.bed``.

        Args:
            args: run arguments (``args.bed``, ``args.local_out_path``, ``args.state_dict()``).
            global_it: global step, passed to omnistore and stored in ``meta.pth``.
            next_ep: stored in ``meta.pth`` as ``'next_ep'``.
            next_it: stored in ``meta.pth`` as ``'next_it'``.
            fsdp_object: the FSDP-wrapped model handed to omnistore.
            optimizer_object: accepted for interface compatibility; not part
                of the saved state.
            acc_str: if given, replaces the stored accuracy string before saving.
            eval_milestone: if given, replaces the stored milestone list before saving.
        """
        if acc_str is not None:
            self.acc_str = acc_str
        if eval_milestone is not None:
            self.eval_milestone = eval_milestone

        stt = time.time()

        checkpoint_state = {
            'model': fsdp_object,
            'extra_state': {}
        }

        from omnistore import FSDPCheckpointer
        print(f"{FSDPCheckpointer=}")

        FSDPCheckpointer.save(
            path=args.bed,
            checkpoint_state=checkpoint_state,
            global_steps=global_it,
            async_fast_checkpoint=True,
            save_flatten_model_optimizer=True,
        )
        if dist.is_master():
            torch.save({
                'args': args.state_dict(),
                'next_ep': next_ep,
                'next_it': next_it,
                'global_it': global_it,
                'acc_str': self.acc_str,
                'milestones': self.eval_milestone,
            }, os.path.join(args.bed, 'meta.pth'))

            if self.sp_backup is not None:
                # Give the previous copy up to 300s to finish. wait() raises
                # TimeoutExpired when it is still running, so the kill has to
                # live in the handler or it would never be reached.
                try:
                    self.sp_backup.wait(timeout=300)
                except subprocess.TimeoutExpired:
                    self.sp_backup.kill()
                self.sp_backup.communicate()
            self.time_stamp[0] = time.time()

            def auto_sync(source_filename, target_filename):
                # Human-readable form, used for the log line only.
                cmd = f'cp -r {source_filename} {target_filename}'
                # An argv list (no shell) keeps paths containing spaces or
                # shell metacharacters intact.
                self.sp_backup = subprocess.Popen(['cp', '-r', source_filename, target_filename], bufsize=-1)
                print(f'[Saver] auto_save cmd: {cmd}', flush=True)

            # NOTE: each launch overwrites self.sp_backup, so only the last
            # copy started here is waited on by the next save.
            for filename in glob.glob(f"{args.local_out_path}/*.txt"):
                basename = os.path.basename(filename)
                target_filename = f'{args.bed}/{basename}'
                auto_sync(filename, target_filename)
            cost = time.time() - stt
            print(f'[CKPTSaver][rank00][omnistore: {FSDPCheckpointer is not None}] cost={cost:.2f}s, ckpt saved to global_step_{global_it}', flush=True)

        dist.barrier()
        del checkpoint_state

    def load(self, ckpt_path, fsdp_object, optimizer_object):
        """Load model weights through omnistore and read training progress from ``meta.pth``.

        ``meta.pth`` is looked up in the parent directory of ``ckpt_path``.

        Args:
            ckpt_path: checkpoint path passed to ``FSDPCheckpointer.load``.
            fsdp_object: the FSDP-wrapped model to load into.
            optimizer_object: accepted for interface compatibility; unused.

        Returns:
            ``(args_state, next_ep, next_it, global_it, acc_str, milestones)``.
            Without a ``meta.pth`` this is ``({}, 0, 0, -1, '', [])``;
            ``global_it`` is also ``-1`` when the meta file lacks that key.
        """
        from omnistore import FSDPCheckpointer
        checkpoint_state = {
            'model': fsdp_object,
            'extra_state': {}
        }
        FSDPCheckpointer.load(
            ckpt_path,
            checkpoint_state,
            load_flatten_model_optimizer=True,
        )
        global_it = -1
        meta_path = os.path.join(os.path.dirname(ckpt_path), 'meta.pth')
        if os.path.exists(meta_path):
            train_meta = torch.load(meta_path)
            args_state, next_ep, next_it, acc_str, milestones = train_meta['args'], train_meta['next_ep'], train_meta['next_it'], train_meta['acc_str'], train_meta['milestones']
            global_it = train_meta.get('global_it', -1)
        else:
            args_state, next_ep, next_it, acc_str, milestones = {}, 0, 0, '', []
        return args_state, next_ep, next_it, global_it, acc_str, milestones
|
|
|
|
|
def merge_ckpt(omnistore_ckpt_path, output_path, fsdp_save_flatten_model, save=False):
    """Convert an omnistore checkpoint into a plain model state dict.

    After conversion a few keys are renamed in place: the leading dot is
    stripped from ``.cfg_uncond``, ``.pos_start`` and ``.sos_token``, and
    ``semantic_head.{weight,bias}`` become ``semantic_head2.{weight,bias}``.

    Args:
        omnistore_ckpt_path: passed to the converter as ``save_path``.
        output_path: passed to the converter as ``output_path``; also the
            directory for ``slim-model.pt`` when ``save`` is true.
        fsdp_save_flatten_model: forwarded to the converter unchanged.
        save: when true, write the renamed state dict to
            ``<output_path>/slim-model.pt``.

    Returns:
        The ``"model"`` entry of the converted checkpoint, after renaming.
    """
    print('merging omnistore ckpt into torch-format ckpt')
    tic = time.time()
    from omnistore.utilities.ckpt_format_tool import omnistore_ckpt_to_pytorch_ckpt
    converted = omnistore_ckpt_to_pytorch_ckpt(
        save_path=omnistore_ckpt_path,
        output_path=output_path,
        framework="fsdp",
        model_only=True,
        return_dict=True,
        fsdp_save_flatten_model=fsdp_save_flatten_model,
    )
    print(f"ckpt merged in {time.time() - tic:.2f} seconds")
    weights = converted["model"]

    # Keys that arrive with a stray leading dot get their plain name back.
    dotted_renames = (
        ('.cfg_uncond', 'cfg_uncond'),
        ('.pos_start', 'pos_start'),
        ('.sos_token', 'sos_token'),
    )
    for old_key, new_key in dotted_renames:
        if old_key in weights:
            weights[new_key] = weights[old_key]
            del weights[old_key]

    if 'semantic_head.weight' in weights:
        print('[rush_resume] replace semantic_head with semantic_head2')
        head_renames = (
            ('semantic_head.weight', 'semantic_head2.weight'),
            ('semantic_head.bias', 'semantic_head2.bias'),
        )
        # Copy both tensors first, then drop the old names.
        for old_key, new_key in head_renames:
            weights[new_key] = weights[old_key]
        for old_key, _ in head_renames:
            del weights[old_key]

    if save:
        save_file = os.path.join(output_path, "slim-model.pt")
        print(f'save to {save_file}')
        torch.save(weights, save_file)
    return weights
|
|
|