|
|
|
|
|
|
|
|
|
|
|
import asyncio |
|
|
import concurrent.futures |
|
|
import hashlib |
|
|
import io |
|
|
import os |
|
|
import pickle |
|
|
import re |
|
|
import socket |
|
|
import stat |
|
|
from asyncio import InvalidStateError |
|
|
from asyncio.tasks import ALL_COMPLETED |
|
|
from datetime import datetime |
|
|
from typing import Any, Awaitable, Callable, Dict, List, Union |
|
|
|
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
|
|
|
from internlm.core.context import global_context as gpc |
|
|
from internlm.utils.common import SingletonMeta |
|
|
from internlm.utils.logger import get_logger |
|
|
|
|
|
try: |
|
|
import boto3 |
|
|
import botocore |
|
|
except ImportError: |
|
|
pass |
|
|
|
|
|
|
|
|
logger = get_logger(__file__)

# Parses the host part of an s3 url: "<bucket>.<dotted-ip>", e.g. "weights.10.1.2.3"
# (group 1 = bucket name, group 2 = endpoint ip).
boto3_url_re = re.compile(r"([^\.]+)\.([\d\.]+)")

# Number of bytes in one mebibyte.
MB = 1024**2

# Process-wide StorageManager singleton; stays None until init_storage_manager()
# is called (re-declared with its type after the class definition below).
storage_manager = None
|
|
|
|
|
|
|
|
def check_folder(fp: str):
    """Assert that *fp* exists on its storage backend, via the global storage manager."""
    storage_manager.assert_fp_exists(fp)
|
|
|
|
|
|
|
|
def get_fns(fp: str):
    """Return the entry names under folder *fp*, via the global storage manager."""
    return storage_manager.get_fns(fp)
|
|
|
|
|
|
|
|
def llm_load(fp: str, **kwargs):
    """Load a checkpoint object from *fp* via the global storage manager.

    kwargs are forwarded to the backend's load (ultimately torch.load).
    """
    return storage_manager.load(fp, **kwargs)
|
|
|
|
|
|
|
|
def llm_save(save_path: str, saved_obj: Any, **kwargs):
    """Save *saved_obj* to *save_path* via the global storage manager.

    kwargs are forwarded to StorageManager.save (e.g. async_upload).
    """
    storage_manager.save(save_path, to_save_obj=saved_obj, **kwargs)
|
|
|
|
|
|
|
|
class StorageClient:
    """
    StorageClient as a client for s3 storage access.

    Abstract base for concrete backends (Boto3Client, LocalClient); subclasses
    override the static methods below.
    """

    def __init__(self, handler) -> None:
        # `handler` is the underlying storage library (e.g. the boto3 module),
        # or None for backends that need no external library.
        self.handler = handler

    @staticmethod
    def load(*args, **kwargs):
        """Load and return an object from the backend."""
        raise NotImplementedError

    @staticmethod
    def sync_upload_fileobj(*args, **kwargs):
        """Serialize and upload an object synchronously."""
        raise NotImplementedError

    @staticmethod
    def async_upload_fileobj(*args, **kwargs):
        """Upload a pre-serialized local staging file (used for async saves)."""
        raise NotImplementedError

    @staticmethod
    def assert_fp_exists(*args, **kwargs):
        """Assert that a path/prefix exists on the backend."""
        raise NotImplementedError

    @staticmethod
    def get_fns(*args, **kwargs):
        """Return the entry names under a folder/prefix."""
        raise NotImplementedError
|
|
|
|
|
|
|
|
class Boto3MetaInfo:
    """Boto3 meta info for save/load etc."""

    def __init__(
        self,
        is_async,
        handler: StorageClient,
        bucket_name: str,
        endpoint: str,
        file_path: str,
        async_upload_fn: callable,
        local_nvme_path=None,
    ) -> None:
        # Grouped by topic: where the object lives, then how it is uploaded.
        self.client = handler
        self.bucket_name = bucket_name
        self.endpoint = endpoint
        self.file_path = file_path
        # Async-upload bookkeeping (staging file + uploader callback).
        self.is_async = is_async
        self.local_nvme_path = local_nvme_path
        self.async_upload_fn = async_upload_fn

    def __str__(self) -> str:
        return f"is_async: {self.is_async}, bucket_name:{self.bucket_name}, endpoint:{self.endpoint}, \
local_nvme_path: {self.local_nvme_path}"

    @staticmethod
    def unpack_boto3_save_meta(meta):
        """Positional args for an upload: client, bucket, key (+ staging path when async)."""
        args = (meta.client, meta.bucket_name, meta.file_path)
        return args + (meta.local_nvme_path,) if meta.is_async else args

    @staticmethod
    def unpack_boto3_nosave_meta(meta):
        """Positional args for a non-upload operation: client, bucket, key."""
        return (meta.client, meta.bucket_name, meta.file_path)
|
|
|
|
|
|
|
|
class LocalMetaInfo:
    """Local meta info for save/load etc."""

    def __init__(self, file_path: str) -> None:
        self.file_path = file_path
        # Local storage never uploads asynchronously, so there is no
        # background-upload callback.
        self.is_async = False
        self.async_upload_fn = None

    @staticmethod
    def unpack_local_save_meta(meta):
        """One-element tuple: the target file path."""
        return (meta.file_path,)

    @staticmethod
    def unpack_local_nosave_meta(meta):
        """One-element tuple: the source file path."""
        return (meta.file_path,)
|
|
|
|
|
|
|
|
def unpack_save_meta(meta: Union[Boto3MetaInfo, LocalMetaInfo]):
    """Return the positional args for the backend's save (upload) call.

    Raises:
        ValueError: if *meta* is neither a Boto3MetaInfo nor a LocalMetaInfo.
    """
    if isinstance(meta, Boto3MetaInfo):
        return Boto3MetaInfo.unpack_boto3_save_meta(meta)
    elif isinstance(meta, LocalMetaInfo):
        return LocalMetaInfo.unpack_local_save_meta(meta)
    else:
        # Fixed typo in the error message ("unkonwn" -> "unknown").
        raise ValueError(f"unknown meta info: {type(meta)}")
|
|
|
|
|
|
|
|
def unpack_nosave_meta(meta: Union[Boto3MetaInfo, LocalMetaInfo]):
    """Return the positional args for a backend's non-save call (load/list/delete).

    Raises:
        ValueError: if *meta* is neither a Boto3MetaInfo nor a LocalMetaInfo.
    """
    if isinstance(meta, Boto3MetaInfo):
        return Boto3MetaInfo.unpack_boto3_nosave_meta(meta)
    elif isinstance(meta, LocalMetaInfo):
        return LocalMetaInfo.unpack_local_nosave_meta(meta)
    else:
        # Fixed typo in the error message ("unkonwn" -> "unknown").
        raise ValueError(f"unknown meta info: {type(meta)}")
|
|
|
|
|
|
|
|
def compute_file_md5_by_chunk(file_name: str):
    """Return the hex MD5 digest of *file_name*, read in 4 KiB chunks."""
    digest = hashlib.md5()
    with open(file_name, "rb") as stream:
        while True:
            block = stream.read(4096)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
|
|
|
|
|
|
|
|
def try_get_storage_backend(path: str):
    """Split a storage path into (backend_name, path).

    Accepts backend-prefixed paths such as "boto3:s3://..." or "local:/mnt/...".
    For un-prefixed paths the backend is guessed: "s3://..." urls map to boto3,
    anything else to local.

    Returns:
        Tuple[str, str]: the backend name and the path with the prefix removed.
    """
    # Bare s3 urls contain ":" and were previously split into ("s3", "//..."),
    # which is not a recognized backend name; handle them before splitting.
    if path.startswith("s3://"):
        if gpc.is_rank_for_log():
            logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of boto3.")
        return "boto3", path

    sre = path.split(":", maxsplit=1)
    if len(sre) == 1:
        # No backend prefix and not an s3 url: assume a local filesystem path.
        # (The old code returned the *list* `sre` here instead of the string.)
        if gpc.is_rank_for_log():
            logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of local.")
        return "local", path
    return sre[0], sre[1]
|
|
|
|
|
|
|
|
class Boto3Client(StorageClient):
    """
    S3 object storage client built on top of boto3.
    """

    def __init__(
        self,
        s3_endpoint_url: str,
        use_threads: bool = True,
        multipart_chunksize=8 * MB,
        max_concurrency: int = 10,
        multipart_threshold=100 * MB,
    ) -> None:
        """S3 object/file storage management class.

        Credentials are read from the environment variables 'S3_ACCESS_KEY_ID'
        and 'S3_SECRET_ACCESS_KEY_ID'.

        Args:
            s3_endpoint_url (str): S3 endpoint url, e.g. "http://10.1.2.3:80".
            use_threads (bool, optional): Whether to enable threaded multipart transfers. Defaults to True.
            multipart_chunksize (int, optional): Chunk size for multipart transfers. Defaults to 8*MB.
            max_concurrency (int, optional): Max concurrent transfer threads. Defaults to 10.
            multipart_threshold (int, optional): Size above which multipart transfer is used. Defaults to 100*MB.

        Raises:
            RuntimeError: Missing credentials, or connection failures caused by
                misconfiguration or network problems.
        """
        super().__init__(boto3)
        self.botocore = botocore
        try:
            s3_access_key_id = os.environ["S3_ACCESS_KEY_ID"]
            s3_secret_access_key = os.environ["S3_SECRET_ACCESS_KEY_ID"]
        except KeyError as exc:
            raise RuntimeError(
                "Please set boto3 bucket 'S3_ACCESS_KEY_ID' and 'S3_SECRET_ACCESS_KEY_ID' using environment variable!"
            ) from exc

        # The second positional argument is boto3's region_name; an empty
        # string is passed deliberately since the endpoint url is explicit.
        self.client = self.handler.client(
            "s3",
            "",
            use_ssl=False,
            verify=False,
            endpoint_url=s3_endpoint_url,
            aws_access_key_id=s3_access_key_id,
            aws_secret_access_key=s3_secret_access_key,
        )

        self.config = self.handler.s3.transfer.TransferConfig(
            multipart_threshold=multipart_threshold,
            max_concurrency=max_concurrency,
            multipart_chunksize=multipart_chunksize,
            use_threads=use_threads,
        )

    @staticmethod
    def sync_upload_fileobj(handler, bucket_name: str, fp: str, saved_obj=None, **kwargs):
        """Serialize *saved_obj* with torch.save and upload it to s3://<bucket_name>/<fp> synchronously."""
        assert saved_obj is not None, "saved_obj is None!"
        try:
            with io.BytesIO() as f:
                torch.save(saved_obj, f, **kwargs)
                f.seek(0)
                handler.client.upload_fileobj(f, bucket_name, fp, Config=handler.config)
        except handler.botocore.exceptions.EndpointConnectionError as exc:
            raise RuntimeError(
                f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
            ) from exc

    @staticmethod
    def load(handler, bucket_name: str, fp: str, **kwargs) -> Dict:
        """Download s3://<bucket_name>/<fp> and deserialize it with torch.load.

        Args:
            fp (str): Path to save, eg. s3://opennlplab/model_weights/xxx/ddd.pt
        """
        try:
            with io.BytesIO() as f:
                handler.client.download_fileobj(bucket_name, fp, f, Config=handler.config)
                f.seek(0)
                states = torch.load(f, **kwargs)
        except handler.botocore.exceptions.EndpointConnectionError as exc:
            raise RuntimeError(
                f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
            ) from exc
        return states

    @staticmethod
    def assert_fp_exists(handler, bucket_name: str, fp: str):
        """Assert that at least one object exists under the prefix *fp*."""
        assert len(list(handler.client.list_objects(Bucket=bucket_name, Prefix=fp)["Contents"])) > 0, fp

    @staticmethod
    def is_fp_exists(handler, bucket_name: str, fp: str):
        """Return True if any object exists under the prefix *fp*."""
        # Renamed the local from `re` to `resp`: `re` shadowed the stdlib
        # regex module imported at the top of the file.
        resp = handler.client.list_objects(Bucket=bucket_name, Prefix=fp)
        return "Contents" in resp and len(list(resp["Contents"])) > 0

    @staticmethod
    def get_fns(handler, bucket_name: str, fp: str):
        """
        List the first-level entry names under prefix *fp*, or None (with a
        warning) if the prefix does not exist.

        Ref: https://stackoverflow.com/questions/54314563/
        how-to-get-more-than-1000-objects-from-s3-by-using-list-objects-v2
        """
        if Boto3Client.is_fp_exists(handler, bucket_name, fp):
            # list_objects_v2 caps responses at 1000 keys; paginate to get all.
            paginator = handler.client.get_paginator("list_objects_v2")
            pages = paginator.paginate(Bucket=bucket_name, Prefix=fp)
            folder_name_list = []
            for page in pages:
                if "Contents" in page:
                    for obj in page["Contents"]:
                        pth: str = obj["Key"]
                        # Keep only the first path component after the prefix.
                        folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
            return list(set(folder_name_list))
        else:
            if gpc.is_rank_for_log():
                logger.warning(f"'{fp}' not found!")
            return None

    @staticmethod
    def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str):
        """Upload the local staging file *local_nvme_path* to s3://<bucket_name>/<fp>."""
        try:
            with open(local_nvme_path, "rb") as f:
                handler.client.upload_fileobj(f, bucket_name, fp, Config=handler.config)
        except handler.botocore.exceptions.EndpointConnectionError as exc:
            raise RuntimeError(
                f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
            ) from exc
        # The previous `except Exception as e: raise e` clause was a no-op and
        # has been removed; other exceptions still propagate unchanged.

    @staticmethod
    def delete_obj(handler, fp: str):
        """Deletion is not supported by the boto3 backend."""
        raise NotImplementedError("boto3 not support delete_obj")
|
|
|
|
|
|
|
|
class LocalClient(StorageClient):
    """
    Storage Client for local NFS.
    """

    def __init__(self, *args, **kwargs) -> None:
        # No external library is needed for local filesystem access.
        super().__init__(None)

    @staticmethod
    def sync_upload_fileobj(fp: str, saved_obj=None, **kwargs):
        """Serialize *saved_obj* to local path *fp* with torch.save, creating parent dirs."""
        assert saved_obj is not None
        fp_dirname = os.path.dirname(fp)
        # Guard against a bare filename (dirname == ""), for which makedirs
        # would raise; `exist_ok=True` already tolerates concurrent creation,
        # so the previous os.path.exists() pre-check was redundant.
        if fp_dirname:
            os.makedirs(fp_dirname, exist_ok=True)
        torch.save(saved_obj, fp, **kwargs)

    @staticmethod
    def load(load_path: str, **kwargs):
        """Deserialize the object stored at *load_path* with torch.load."""
        assert os.path.exists(load_path), f"{load_path} is not found!"
        with open(load_path, "rb") as f:
            states = torch.load(f, **kwargs)
        return states

    @staticmethod
    def assert_fp_exists(folder):
        """Assert that *folder* exists on the local filesystem."""
        assert os.path.exists(folder), folder

    @staticmethod
    def get_fns(folder):
        """Return entries under *folder*, or None (with a warning) if missing."""
        if not os.path.exists(folder):
            if gpc.is_rank_for_log():
                logger.warning(f"'{folder}' not found!")
            return None
        return os.listdir(folder)

    @staticmethod
    def delete_obj(fp: str):
        """Remove *fp* if it is a regular file; directories are left untouched."""
        if not os.path.isdir(fp):
            os.remove(fp)
|
|
|
|
|
|
|
|
def get_tmp_file_name(tmp_local_folder: str, fp: str):
    """
    Build a staging file name inside *tmp_local_folder* for checkpoint *fp*.

    It should be noted that all our temporary files will be stored in the same
    folder, so the file name passed upstream must be unique. Uniqueness is
    strengthened with a timestamp and the current pid.
    """
    leaf = fp.split("/")[-1]
    stamp = datetime.now().strftime("%b%d_%H-%M-%S")
    prefix = os.path.join(tmp_local_folder, leaf)
    return f"{prefix}-{stamp}-{os.getpid()}.tmpfile"
|
|
|
|
|
|
|
|
def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaInfo:
    """Parse an "s3://<bucket>.<ip>/<key>" url into a Boto3MetaInfo.

    Args:
        fp (str): s3 url, e.g. "s3://weights.10.1.2.3/run1/model.pt".
        tmp_local_folder (str): staging folder used for async uploads.
        is_async (bool): whether a local staging file name should be prepared.
    """
    assert fp.startswith("s3://"), f"Path '{fp}' is not a boto3 url"
    # BUGFIX: the previous `fp.lstrip("s3://")` strips *characters* from the
    # set {'s', '3', ':', '/'} rather than the prefix, corrupting bucket names
    # that begin with any of those characters. Slice the prefix off instead.
    parts = fp[len("s3://"):].split(os.path.sep)
    match = boto3_url_re.match(parts[0])
    assert match is not None, f"url '{fp}' is not a valid boto3 url"
    bucket_name, endpoint = match.group(1), match.group(2)
    endpoint = "http://" + endpoint + ":80"
    # Async saves are first written to a unique local staging file.
    if is_async:
        tmp_step_file = get_tmp_file_name(tmp_local_folder, fp)
    else:
        tmp_step_file = None
    return Boto3MetaInfo(
        is_async=is_async,
        handler=None,
        bucket_name=bucket_name,
        endpoint=endpoint,
        file_path=os.path.sep.join(parts[1:]),
        async_upload_fn=Boto3Client.async_upload_fileobj,
        local_nvme_path=tmp_step_file,
    )
|
|
|
|
|
|
|
|
def get_local_meta(fp: str) -> LocalMetaInfo:
    """Wrap a local filesystem path in a LocalMetaInfo; rejects s3:// urls."""
    assert not fp.startswith("s3://"), f"Path '{fp}' is not a local path"
    return LocalMetaInfo(fp)
|
|
|
|
|
|
|
|
def get_mount_point_free_size(path: str):
    """
    Return the free space of the mount point containing *path*, in GiB.

    (The old docstring said "percentage"; the value has always been GiB:
    free blocks * block size / 1024**3.)

    Args:
        path (str): temporary storage folder path.

    Raises:
        FileNotFoundError: If the temporary storage folder does not exist.
            Previously a missing path fell through and returned None, which
            made the caller's `free_size < 0.1` comparison raise TypeError.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Temporary storage folder '{path}' does not exist")
    st = os.statvfs(path)
    return st.f_bavail * st.f_bsize / (1024**3)
|
|
|
|
|
|
|
|
def check_tmp_folder_accessibility(tmp_local_folder: str):
    """
    Check access permissions for temporary storage.

    A non-existent folder is silently accepted (it will be created later);
    an existing folder must be both readable and writable.

    Raises:
        RuntimeError: if the folder exists but lacks read or write permission.
    """
    if os.path.exists(tmp_local_folder):
        ret = os.access(tmp_local_folder, os.W_OK)
        ret &= os.access(tmp_local_folder, os.R_OK)
        if ret is False:
            # Fixed message typo ("dose" -> "does") and dropped the stray
            # trailing double quote.
            error_str = f"{socket.gethostname()} does not have read and write permissions on {tmp_local_folder}"
            raise RuntimeError(error_str)
|
|
|
|
|
|
|
|
class StorageManager(metaclass=SingletonMeta):
    """
    Storage Manager for saving or loading checkpoint.
    TODO: add a thread to poll the asynchronous storage state.
    """

    # Names of the supported storage backends.
    BACKEND_TYPE = {"boto3", "local"}
    # Backend name -> client class, used to lazily construct clients.
    BACKEND_INIT_METHOD = {
        "boto3": Boto3Client,
        "local": LocalClient,
    }
    # Class-level cache of constructed clients, keyed by backend name
    # (plus endpoint for boto3); shared across instances.
    CLI_DICT = {}

    def __init__(self, enable_save, tmp_local_folder="/dev/shm/test/", async_mode=True, n_async_workers=8) -> None:
        """Set up the manager; async machinery is built only when saving is enabled.

        Args:
            enable_save: whether checkpoint saving is enabled.
            tmp_local_folder (str): local staging folder for async uploads.
            async_mode (bool): default for asynchronous upload in save().
            n_async_workers (int): thread-pool size for background uploads.
        """
        self._exception_list = []  # (exception, index) pairs collected by _sync_tasks()
        self._to_be_del_files = []  # staging files to delete once uploads finish
        self._async_stack = []  # pending upload futures
        self.upload_count = 0
        self.tmp_local_folder = tmp_local_folder
        self.async_mode = async_mode
        self.has_warning = False  # proxy warning is emitted at most once
        self._async_loop = None
        self._thread_pool = None
        self.latest_save_folder = None
        self.latest_save_step = 0
        # True while at least one async upload is scheduled and not yet drained.
        self.async_task_peeding = False

        if enable_save and self.async_mode:
            self._async_loop = asyncio.new_event_loop()
            self._thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=n_async_workers)

            # The parent mount point must be accessible before we try to
            # create the staging folder itself.
            check_tmp_folder_accessibility(os.path.dirname(self.tmp_local_folder))

            # Create the staging folder, world-readable/writable so that other
            # ranks/users on the same node can share it.
            try:
                os.makedirs(self.tmp_local_folder, exist_ok=True)
                os.chmod(self.tmp_local_folder, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
            except FileExistsError:
                pass

            # The folder may have been created by another user; re-check our
            # own read/write permission on it.
            check_tmp_folder_accessibility(self.tmp_local_folder)

            # Clean up *.tmpfile leftovers from previous (possibly crashed) runs.
            self.try_delete_tmpfile(self.tmp_local_folder)

            # Available storage space check.
            free_size = get_mount_point_free_size(self.tmp_local_folder)
            if free_size < 0.1:
                # NOTE(review): the threshold is 0.1 GB but the message says
                # 100 GB — confirm which one is intended.
                logger.error(f'tmp_local_folder only have "{free_size}" GB free space, less then 100 GB!')
                raise RuntimeError(f"Insufficient temporary storage space on {socket.gethostname()}")

    def _get_client(self, path: str, async_mode: bool = False) -> Union[Boto3MetaInfo, LocalMetaInfo]:
        """
        Resolve *path* to a meta-info object with a cached client attached.

        Path examples:
            local:/path/to/checkpoint
            boto3:s3://model_weights/0331/120bi

        Args:
            path (str): backend-prefixed storage path.
            async_mode (bool): whether an async staging file name should be
                prepared (only meaningful for the boto3 backend).
        """
        backend, path = try_get_storage_backend(path)

        init_args = (None,)
        if backend == "local":
            meta_info = get_local_meta(path)
            backend_key = backend
        elif backend == "boto3":
            meta_info = get_boto3_meta(path, self.tmp_local_folder, async_mode)
            # One boto3 client per endpoint.
            backend_key = backend + ":" + meta_info.endpoint
            init_args = (meta_info.endpoint,)
            if (
                "http_proxy" in os.environ
                or "https_proxy" in os.environ
                or "HTTP_PROXY" in os.environ
                or "HTTPS_PROXY" in os.environ
            ):
                if not self.has_warning:
                    logger.warning(
                        "HTTP/HTTPS proxy is detected when using boto3, incorrectly setting \
the proxy may make boto3 unavailable or affect performance."
                    )
                    self.has_warning = True

        # NOTE(review): for an unknown backend, `meta_info` is unbound and a
        # NameError would occur before this assert's message — confirm.
        assert backend in StorageManager.BACKEND_TYPE, f"Unkown backend: {backend}"

        # Lazily construct and cache one client per backend(+endpoint).
        if backend_key not in StorageManager.CLI_DICT:
            StorageManager.CLI_DICT.update({backend_key: StorageManager.BACKEND_INIT_METHOD[backend](*init_args)})

        meta_info.client = StorageManager.CLI_DICT[backend_key]

        return meta_info

    def assert_fp_exists(self, folder) -> None:
        """Assert that *folder* exists on its backend."""
        meta = self._get_client(path=folder)
        meta.client.assert_fp_exists(*unpack_nosave_meta(meta))

    def get_fns(self, folder) -> List[str]:
        """Return entry names under *folder*, or None if it does not exist."""
        meta = self._get_client(path=folder)
        return meta.client.get_fns(*unpack_nosave_meta(meta))

    def save(self, save_path: str, to_save_obj: Any, async_upload=None, **kwargs):
        """Save *to_save_obj* to *save_path*, optionally in the background.

        Async upload is only supported for boto3 paths: the object is first
        serialized to a local staging file, then uploaded by a worker thread.

        Args:
            save_path (str): backend-prefixed destination path.
            to_save_obj (Any): object to serialize with torch.save.
            async_upload (bool, optional): overrides self.async_mode when given.
        """
        if async_upload is None:
            async_upload = self.async_mode

        # Local saves are always synchronous.
        if not save_path.startswith("boto3:"):
            async_upload = False

        meta = self._get_client(save_path, async_upload)

        if async_upload:
            assert (
                self.tmp_local_folder
            ), "StorageManager is not setted tmp_local_folder, so async save cannot be performed."
            tmp_step_file = meta.local_nvme_path
            # Record the staging file so wait() can delete it after upload.
            self._to_be_del_files.append(tmp_step_file)
            with open(tmp_step_file, "wb") as f:
                torch.save(to_save_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
            self.async_executor(meta.async_upload_fn, *unpack_save_meta(meta))
            # NOTE(review): chmod runs after the upload task is scheduled —
            # confirm the background reader does not depend on the mode change.
            os.chmod(tmp_step_file, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
            self.async_task_peeding = True
        else:
            meta.client.sync_upload_fileobj(*unpack_save_meta(meta), saved_obj=to_save_obj, **kwargs)
            self.upload_count += 1

    def load(self, load_path: str, **kwargs) -> Any:
        """Load an object from *load_path*, draining pending async uploads first."""
        # Ensure a checkpoint we may be about to read has finished uploading.
        self.wait()
        meta = self._get_client(path=load_path)
        return meta.client.load(*unpack_nosave_meta(meta), **kwargs)

    def delete_obj(self, fp: str):
        """Delete the object at *fp* (only supported by the local backend)."""
        meta = self._get_client(path=fp)
        meta.client.delete_obj(*unpack_nosave_meta(meta))

    def _del_tmp_folder(self):
        """Best-effort removal of the staging files recorded in _to_be_del_files."""
        for fp in self._to_be_del_files:
            try:
                os.remove(fp)
            except FileNotFoundError:
                pass
            except SystemError as e:
                logger.error(f'delete file: {fp}, failed for reason:"{e}"')
            else:
                pass

    def try_delete_tmpfile(self, tmp_dir: str):
        """Delete temporary files in tmp_dir."""
        for filename in os.listdir(tmp_dir):
            if filename.endswith(".tmpfile"):
                file_path = os.path.join(tmp_dir, filename)
                try:
                    os.remove(file_path)
                    logger.info(f"Delete tmpfile: {file_path}")
                except OSError:
                    # Ignore deletion failures (busy/permission); leftovers are
                    # retried on the next startup.
                    pass

    async def _sync_tasks(self) -> Awaitable[None]:
        """Await all pending upload futures and collect their exceptions.

        Exceptions are stored in _exception_list as (exception, index) pairs,
        where index matches the file's position in _to_be_del_files.
        """
        if self._async_stack:
            await asyncio.wait(self._async_stack, return_when=ALL_COMPLETED)
            count = 0
            while self._async_stack:
                t = self._async_stack[0]
                try:
                    e = t.exception()
                    if e:
                        self._exception_list.append((e, count))
                        logger.error(f"File:{self._to_be_del_files[count]}, upload failed for {e}")

                    count += 1
                    self._async_stack.pop(0)
                except InvalidStateError:
                    # Future is not yet done; leave it on the stack.
                    pass

    def async_executor(self, fn: Callable, *args, **kwargs) -> None:
        """
        Overview:
            Execute *fn* in the background thread pool, then append the future
            instance to _async_stack.
        Arguments:
            - fn (:obj:`Callable`): Synchronous upload function.
        """
        if not self._async_loop:
            raise RuntimeError("Event loop was not initialized, please call this function in async or parallel mode")
        t = self._async_loop.run_in_executor(self._thread_pool, fn, *args, **kwargs)
        self._async_stack.append(t)

    def wait(self) -> bool:
        """Wait for async operations to complete.

        Raises RuntimeError if any upload failed; on success deletes the
        staging files and (on the logging rank) writes a "<step>.step" marker.
        """
        if not self.async_mode:
            return

        # Nothing was scheduled since the last drain.
        if not self.async_task_peeding:
            return

        if self._async_loop:
            self._async_loop.run_until_complete(self._sync_tasks())

        if self._exception_list:
            for error_msg, file_id in self._exception_list:
                logger.error(
                    f"Node:{socket.gethostname()}, Error: Checkpoint {self._to_be_del_files[file_id]} "
                    f"failed on step {self.upload_count}: {error_msg}"
                )

                # NOTE(review): raises on the first failed upload, so remaining
                # failures are never logged — confirm this is intended.
                raise RuntimeError(
                    f"Failed to upload {self._to_be_del_files[file_id]} " f"on step {self.upload_count}: {error_msg}"
                )

        self._del_tmp_folder()
        self._exception_list.clear()
        self._to_be_del_files.clear()
        self.async_task_peeding = False

        if gpc.is_rank_for_log():
            self.upload_count += 1
            if self.async_mode and self.latest_save_folder:
                # Marker file signalling that this step's async save has landed.
                self.save(
                    os.path.join(self.latest_save_folder, f"{self.latest_save_step}.step"),
                    to_save_obj=dict({"step": self.latest_save_step}),
                    async_upload=False,
                )
                self.latest_save_folder = None
|
|
|
|
|
|
|
|
# Re-declaration of the module-level singleton with its concrete type; it stays
# None until init_storage_manager() is called.
storage_manager: StorageManager = None
|
|
|
|
|
|
|
|
def init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload):
    """Construct the module-level StorageManager singleton.

    Must be called before any of the module-level helpers (llm_save, llm_load,
    get_fns, check_folder) are used.
    """
    global storage_manager
    storage_manager = StorageManager(
        enable_save_ckpt,
        tmp_local_folder=async_upload_tmp_folder,
        async_mode=async_upload,
    )
|
|
|
|
|
|
|
|
def get_storage_manager():
    """Return the global StorageManager; init_storage_manager() must have run first."""
    assert storage_manager is not None, "storage_manager has not been init!"
    return storage_manager
|
|
|
|
|
|
|
|
def wait_async_upload_finish():
    """Synchronize all ranks, then block until local async uploads complete."""
    # Barrier first so that every rank has finished scheduling its uploads.
    dist.barrier()
    storage_manager.wait()
|
|
|