# Uploaded by BryanW via huggingface_hub (commit 3d1c0e1, verified)
# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT
import gc
import os
import os.path as osp
import subprocess
import time
import re
from typing import List, Optional, Tuple
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import glob
import shutil
from infinity.utils import arg_util
import infinity.utils.dist as dist
def glob_with_epoch_iter(pattern, recursive=False):
    """Glob for checkpoint paths and order them newest-first by (epoch, iter).

    Filenames are expected to carry an ``ep<E>-iter<I>`` tag; paths without
    the tag sort with key ``(0, 0)``, i.e. after all tagged ones.
    """
    def _sort_key(path):
        tag = re.search(r'ep(\d+)-iter(\d+)', os.path.basename(path))
        if tag is None:
            return 0, 0
        return int(tag.group(1)), int(tag.group(2))

    matches = glob.glob(pattern, recursive=recursive)
    matches.sort(key=_sort_key, reverse=True)
    return matches
def glob_with_global_step(pattern, recursive=False):
    """Glob for checkpoint paths and order them newest-first by global step.

    Filenames are expected to carry a ``global_step_<N>`` tag; paths without
    the tag sort as step 0, i.e. after all tagged ones.
    """
    def _step_of(path):
        tag = re.search(r'global_step_(\d+)', os.path.basename(path))
        return int(tag.group(1)) if tag else 0

    return sorted(glob.glob(pattern, recursive=recursive), key=_step_of, reverse=True)
class CKPTSaver(object):
    """Saves full training checkpoints (args + trainer state) as .pth files.

    Only the master rank writes the file; every rank synchronizes on a
    barrier afterwards so no rank races ahead of the save.
    """

    def __init__(self, is_master: bool, eval_milestone: List[Tuple[float, float]]):
        self.is_master = is_master
        # [time of previous save, current time]; stored as a tensor on the
        # local device (presumably so it could be reduced across ranks — TODO confirm).
        self.time_stamp = torch.tensor([time.time() - 1e5, time.time()], device=dist.get_device())
        # Handles for background copy subprocesses (mirrors omnistoreCheckpoint).
        self.sp_also: subprocess.Popen = None
        self.sp_best: subprocess.Popen = None
        self.sp_backup: subprocess.Popen = None
        self.acc_str, self.eval_milestone = '[no acc str]', eval_milestone

    def sav(
        self, args: arg_util.Args, g_it: int, next_ep: int, next_it: int, trainer,
        acc_str: Optional[str] = None, eval_milestone: Optional[List[Tuple[float, float]]] = None,
        also_save_to: str = None, best_save_to: str = None,
    ):
        """Save one checkpoint named ``global_step_{g_it}.pth`` under args.bed.

        ``trainer.state_dict()`` is gathered on every rank, but only the
        master rank writes the file. ``also_save_to`` / ``best_save_to`` are
        accepted for interface compatibility but are currently unused.
        """
        # BUGFIX: acc_str / eval_milestone arguments were previously ignored;
        # record them like omnistoreCheckpoint.sav so they reach the checkpoint.
        if acc_str is not None:
            self.acc_str = acc_str
        if eval_milestone is not None:
            self.eval_milestone = eval_milestone
        fname = f'global_step_{g_it}.pth'
        local_out_ckpt = os.path.join(args.bed, fname)
        trainer_state = trainer.state_dict()
        stt = time.time()
        if self.is_master:
            torch.save({
                'args': args.state_dict(),
                'arch': args.model,
                'epoch': next_ep,
                'iter': next_it,
                'trainer': trainer_state,
                'acc_str': self.acc_str,
                'g_it': g_it,
            }, local_out_ckpt)
            cost = time.time() - stt
            print(f'Checkpoint save cost: {cost:.2f}s', flush=True)
            print(f'Checkpoint save to: {local_out_ckpt}', flush=True)
        # Free the (potentially huge) state dict before resuming training.
        del trainer_state
        gc.collect()  # BUGFIX: removed stray trailing comma (it built a useless 1-tuple)
        torch.cuda.empty_cache()
        dist.barrier()
def auto_resume(args: arg_util.Args, pattern='*.pth') -> Tuple[List[str], int, int, str, List[Tuple[float, float]], dict, dict]:
    """Locate and load the newest checkpoint under ``args.bed``, if enabled.

    Returns ``(log_lines, epoch, global_iter, acc_str, eval_milestone,
    trainer_state, args_state)``; zero/empty defaults when nothing resumes.
    """
    log: List[str] = []
    ckpt_path = ''
    if not args.auto_resume:
        log.append(f'[auto_resume] disabled')
        log.append(f'[auto_resume quit]')
    else:
        candidates = glob_with_global_step(os.path.join(args.bed, pattern))
        if candidates:
            ckpt_path = candidates[0]
            log.append(f'[auto_resume] auto load from @ {ckpt_path} ...')
        else:
            log.append(f'[auto_resume] no ckpt found @ {pattern}')
            log.append(f'[auto_resume quit]')
    if len(ckpt_path) == 0:
        return log, 0, 0, '[no acc str]', [], {}, {}
    print(f'auto resume from {ckpt_path}')
    ckpt = torch.load(ckpt_path, map_location='cpu')
    dist.barrier()
    ep, it, g_it = ckpt['epoch'], ckpt['iter'], ckpt['g_it']
    eval_milestone = ckpt.get('milestones', [])
    log.append(f'[auto_resume success] resume from ep{ep}, it{it}, eval_milestone: {eval_milestone}')
    return log, ep, g_it, ckpt.get('acc_str', '[no acc str]'), eval_milestone, ckpt['trainer'], ckpt['args']
def omnistore_auto_resume(args: arg_util.Args, pattern='ckpt*.pth'):
    """Find the newest omnistore checkpoint, preferring local_out_path over bed.

    Returns ``(resume_path, log_lines)``; ``resume_path`` is '' when
    auto-resume is disabled or no checkpoint matches.
    """
    log = []
    if not args.auto_resume:
        log.append(f'[auto_resume] disabled')
        log.append(f'[auto_resume quit]')
        return '', log
    found = []
    for root in (args.local_out_path, args.bed):
        found = glob_with_global_step(os.path.join(root, pattern))
        if found:
            break
    if not found:
        log.append(f'[auto_resume] no ckpt found @ {pattern}')
        log.append(f'[auto_resume quit]')
        return '', log
    log.append(f'[auto_resume] auto load from @ {found[0]} ...')
    return found[0], log
class omnistoreCheckpoint(object):
    """Checkpointing via omnistore's FSDPCheckpointer plus a small meta file.

    ``sav`` asynchronously saves the FSDP-sharded model through omnistore
    and, on the master rank, writes training metadata (``meta.pth``) and
    copies local ``*.txt`` logs into ``args.bed`` with background ``cp``.
    """

    def __init__(self, eval_milestone: List[Tuple[float, float]]):
        # [time of previous save, current time]; a device tensor like CKPTSaver's.
        self.time_stamp = torch.tensor([time.time() - 1e5, time.time()], device=dist.get_device())
        # Handles for background copy subprocesses.
        self.sp_also: subprocess.Popen = None
        self.sp_best: subprocess.Popen = None
        self.sp_backup: subprocess.Popen = None
        self.acc_str, self.eval_milestone = '[no acc str]', eval_milestone

    def sav(
        self, args: arg_util.Args, global_it: int, next_ep: int, next_it: int, fsdp_object: FSDP, optimizer_object: torch.optim.Optimizer,
        acc_str: Optional[str] = None, eval_milestone: Optional[List[Tuple[float, float]]] = None,
    ):
        """Save the sharded model via omnistore; master also writes meta.pth.

        ``optimizer_object`` is accepted but its state is deliberately not
        saved (see the commented-out 'optimizer' entry below).
        """
        if acc_str is not None: self.acc_str = acc_str
        if eval_milestone is not None: self.eval_milestone = eval_milestone
        stt = time.time()
        checkpoint_state = {
            # EMA/optimizer saving intentionally disabled; kept for reference:
            # 'model': {
            #     'main_model': fsdp_object,
            #     'ema_model': ema_fsdp_object,
            # },
            'model': fsdp_object,
            # 'optimizer': optimizer_object,
            'extra_state': {}
        }
        from omnistore import FSDPCheckpointer
        print(f"{FSDPCheckpointer=}")
        FSDPCheckpointer.save(
            path=args.bed,
            checkpoint_state=checkpoint_state,
            global_steps=global_it,
            async_fast_checkpoint=True,
            save_flatten_model_optimizer=True,
        )
        if dist.is_master():
            torch.save({
                'args': args.state_dict(),
                'next_ep': next_ep,
                'next_it': next_it,
                'global_it': global_it,
                'acc_str': self.acc_str,
                'milestones': self.eval_milestone,
            }, os.path.join(args.bed, 'meta.pth'))
            # Reap the previous background copy before launching new ones.
            if self.sp_backup is not None:
                self.sp_backup.wait(timeout=300); self.sp_backup.kill(); self.sp_backup.communicate()
            self.time_stamp[0] = time.time()

            def auto_sync(source_filename, target_filename):
                # NOTE(review): shell=True with interpolated paths is unsafe for
                # untrusted names; paths here come from args, so kept as-is.
                # NOTE(review): each call overwrites self.sp_backup, so only the
                # last spawned cp is reaped on the next save — confirm intended.
                cmd = f'cp -r {source_filename} {target_filename}'
                self.sp_backup = subprocess.Popen(cmd, shell=True, bufsize=-1)
                print(f'[Saver] auto_save cmd: {cmd}', flush=True)

            local_files = glob.glob(f"{args.local_out_path}/*.txt")
            for filename in local_files:
                basename = os.path.basename(filename)
                target_filename = f'{args.bed}/{basename}'
                auto_sync(filename, target_filename)
            cost = time.time() - stt
            # BUGFIX: use the already-computed `cost` instead of recomputing the
            # elapsed time inside the f-string (previously `cost` was dead).
            print(f'[CKPTSaver][rank00][omnistore: {FSDPCheckpointer is not None}] cost={cost:.2f}s, ckpt saved to global_step_{global_it}', flush=True)
        dist.barrier()
        del checkpoint_state

    def load(self, ckpt_path, fsdp_object, optimizer_object):
        """Load sharded model weights via omnistore plus sidecar metadata.

        Returns ``(args_state, next_ep, next_it, global_it, acc_str,
        milestones)``; defaults ``({}, 0, 0, -1, '', [])`` when meta.pth is
        absent. ``optimizer_object`` is unused (optimizer state not loaded).
        """
        from omnistore import FSDPCheckpointer
        checkpoint_state = {
            'model': fsdp_object,
            # 'optimizer': optimizer_object,
            'extra_state': {}
        }
        FSDPCheckpointer.load(
            ckpt_path,
            checkpoint_state,
            load_flatten_model_optimizer=True,
        )
        global_it = -1
        meta_path = os.path.join(os.path.dirname(ckpt_path), 'meta.pth')
        if os.path.exists(meta_path):
            # CONSISTENCY FIX: map_location='cpu' like auto_resume, so loading
            # never requires the device the checkpoint was saved from.
            train_meta = torch.load(meta_path, map_location='cpu')
            args_state, next_ep, next_it, acc_str, milestones = train_meta['args'], train_meta['next_ep'], train_meta['next_it'], train_meta['acc_str'], train_meta['milestones']
            global_it = train_meta.get('global_it', -1)
        else:
            args_state, next_ep, next_it, acc_str, milestones = {}, 0, 0, '', []
        return args_state, next_ep, next_it, global_it, acc_str, milestones
def merge_ckpt(omnistore_ckpt_path, output_path, fsdp_save_flatten_model, save=False):
    """Convert an omnistore FSDP checkpoint into a plain PyTorch state dict.

    Renames the root buffers that flattened-FSDP export prefixes with '.',
    migrates ``semantic_head`` weights to ``semantic_head2``, and optionally
    saves the result as ``slim-model.pt`` under ``output_path``.
    """
    print(f'merging omnistore ckpt into torch-format ckpt')
    start = time.time()
    from omnistore.utilities.ckpt_format_tool import omnistore_ckpt_to_pytorch_ckpt
    state_dict = omnistore_ckpt_to_pytorch_ckpt(
        save_path=omnistore_ckpt_path,
        output_path=output_path,
        framework="fsdp",
        model_only=True,
        return_dict=True,
        fsdp_save_flatten_model=fsdp_save_flatten_model,
    )
    print(f"ckpt merged in {time.time() - start:.2f} seconds")
    model_sd = state_dict["model"]
    # Strip the leading '.' from root-level buffer names, if present.
    for dotted in ('.cfg_uncond', '.pos_start', '.sos_token'):
        if dotted in model_sd:
            model_sd[dotted[1:]] = model_sd.pop(dotted)
    if 'semantic_head.weight' in model_sd:
        print(f'[rush_resume] replace semantic_head with semantic_head2')
        model_sd['semantic_head2.weight'] = model_sd.pop('semantic_head.weight')
        model_sd['semantic_head2.bias'] = model_sd.pop('semantic_head.bias')
    if save:
        save_file = os.path.join(output_path, "slim-model.pt")
        print(f'save to {save_file}')
        torch.save(model_sd, save_file)
    return model_sd