# model_merger.py
import os
import re
import shutil
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import torch
from torch.distributed._tensor import DTensor, Placement, Shard
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForTokenClassification,
AutoModelForVision2Seq,
PreTrainedModel,
)
def merge_by_placement(tensors: List[torch.Tensor], placement: Placement) -> torch.Tensor:
"""Merge tensors according to their placement."""
if placement.is_replicate():
return tensors[0]
elif placement.is_partial():
raise NotImplementedError("Partial placement is not supported yet")
elif placement.is_shard():
return torch.cat(tensors, dim=placement.dim).contiguous()
else:
raise ValueError(f"Unsupported placement: {placement}")
def get_model_class(config: AutoConfig) -> PreTrainedModel:
"""Determine the appropriate model class based on config."""
if "ForTokenClassification" in config.architectures[0]:
return AutoModelForTokenClassification
elif "ForCausalLM" in config.architectures[0]:
return AutoModelForCausalLM
elif "ForConditionalGeneration" in config.architectures[0]:
return AutoModelForVision2Seq
else:
raise NotImplementedError(f"Unknown architecture {config.architectures}")
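# Illustrative mapping (architecture names are examples, not exhaustive):
# architectures=["Qwen2ForCausalLM"] -> AutoModelForCausalLM, while a
# "...ForConditionalGeneration" architecture -> AutoModelForVision2Seq.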
def load_sharded_state_dicts(local_dir: str) -> Tuple[List[dict], int, Tuple[int, ...], Tuple[str, ...]]:
"""Load all sharded state dicts and return mesh information."""
# Find world size and rank 0 file
world_size = 0
for filename in os.listdir(local_dir):
match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename)
if match:
world_size = int(match.group(1))
break
if not world_size:
raise ValueError("No model file with the proper format found")
# Load rank 0 to get mesh info
    rank0_state = torch.load(
        os.path.join(local_dir, f"model_world_size_{world_size}_rank_0.pt"),
        map_location="cpu",
        weights_only=False,  # the state dict holds DTensors, not plain tensors
    )
pivot_key = sorted(rank0_state.keys())[0]
weight = rank0_state[pivot_key]
if not isinstance(weight, DTensor):
raise TypeError("Expected DTensor in state dict")
device_mesh = weight.device_mesh
mesh = device_mesh.mesh
mesh_dim_names = device_mesh.mesh_dim_names
print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}")
    if mesh_dim_names != ("fsdp",):
        raise ValueError(f"Unsupported mesh_dim_names {mesh_dim_names}")
# Prepare list for all state dicts
state_dicts = [rank0_state] + [None] * (world_size - 1)
# Load remaining shards in parallel
def load_shard(rank):
if rank == 0:
return rank0_state
model_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
return torch.load(model_path, map_location="cpu", weights_only=False)
with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor:
for rank, state_dict in enumerate(executor.map(load_shard, range(world_size))):
state_dicts[rank] = state_dict
return state_dicts, world_size, mesh.shape, mesh_dim_names
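# Expected checkpoint layout (inferred from the filename pattern matched
# above), e.g. for world_size=4:
#
#   local_dir/
#     model_world_size_4_rank_0.pt
#     ...
#     model_world_size_4_rank_3.pt
#     huggingface/    <- config/tokenizer files used later by save_merged_model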
def merge_state_dicts(
state_dicts: List[dict],
world_size: int,
mesh_shape: Tuple[int, ...],
mesh_dim_names: Tuple[str, ...]
) -> dict:
"""Merge sharded state dicts into a single state dict."""
merged_state = {}
    param_placements: Dict[str, Tuple[Placement, ...]] = {}
keys = set(state_dicts[0].keys())
for key in keys:
shards = []
for state_dict in state_dicts:
tensor = state_dict[key]
if isinstance(tensor, DTensor):
shards.append(tensor._local_tensor.bfloat16())
placements = tuple(tensor.placements)
                # If the leading mesh dim is data-parallel, the first placement
                # is replication across dp ranks; drop it so only the actual
                # shard placements remain.
                if mesh_dim_names[0] == "dp":
                    placements = placements[1:]
if key not in param_placements:
param_placements[key] = placements
else:
assert param_placements[key] == placements
else:
# Non-DTensor values (like buffers) are the same across ranks
merged_state[key] = tensor.bfloat16()
break
if key in merged_state:
continue
# Merge shards according to their placements
placements = param_placements[key]
if len(mesh_shape) == 1:
# 1-D sharding (FSDP only)
assert len(placements) == 1
merged_state[key] = merge_by_placement(shards, placements[0])
else:
# 2-D sharding (FSDP + TP)
raise NotImplementedError("FSDP + TP is not supported yet")
return merged_state
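# Sketch of the 1-D FSDP case handled above (shapes are hypothetical): a
# Shard(0) parameter of full shape [1024, 4096] saved across world_size=4
# arrives as four [256, 4096] local tensors, and merge_by_placement
# concatenates them along dim 0 to recover the [1024, 4096] weight.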
def save_merged_model(
local_dir: str,
merged_state: dict,
hf_upload_path: Optional[str] = None
) -> None:
"""Save merged model and optionally upload to Hugging Face Hub."""
hf_path = os.path.join(local_dir, "huggingface")
config = AutoConfig.from_pretrained(hf_path)
model_class = get_model_class(config)
# Create model on meta device first to save memory
with torch.device("meta"):
model = model_class.from_config(config, torch_dtype=torch.bfloat16)
# Load state dict onto CPU
model.to_empty(device="cpu")
model.load_state_dict(merged_state)
print(f"Saving model to {hf_path}")
model.save_pretrained(hf_path)
if hf_upload_path:
from huggingface_hub import HfApi
api = HfApi()
api.create_repo(repo_id=hf_upload_path, private=False, exist_ok=True)
api.upload_folder(folder_path=hf_path, repo_id=hf_upload_path, repo_type="model")
def merge_and_save_model(local_dir: str, hf_upload_path: Optional[str] = None) -> None:
"""Main function to merge sharded models and save the result."""
# Load all sharded state dicts
state_dicts, world_size, mesh_shape, mesh_dim_names = load_sharded_state_dicts(local_dir)
# Merge state dicts
merged_state = merge_state_dicts(state_dicts, world_size, mesh_shape, mesh_dim_names)
# Save merged model
save_merged_model(local_dir, merged_state, hf_upload_path)
# reorganize_folders(local_dir)
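# Usage sketch (paths are placeholders):
#
#   merge_and_save_model("/path/to/checkpoint/actor")
#   # or, to also push the merged weights to the Hub:
#   merge_and_save_model("/path/to/checkpoint/actor", hf_upload_path="user/repo")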
def reorganize_folders(root_dir: str) -> None:
"""
重组文件夹结构:
1. 将actor/huggingface重命名为models并移动到父目录
2. 删除除新models文件夹外的所有内容
参数:
root_dir: 最外层目录路径 (示例中的'step_20_reward_0.676')
"""
root_path = Path(root_dir)
actor_path = root_path / "actor"
huggingface_path = actor_path / "huggingface"
# 验证目录结构是否符合预期
if not actor_path.exists():
raise FileNotFoundError(f"未找到actor目录: {actor_path}")
if not huggingface_path.exists():
raise FileNotFoundError(f"未找到huggingface目录: {huggingface_path}")
# 新models目录路径 (与actor同级)
models_path = root_path / "models"
print(f"正在将 {huggingface_path} 移动到 {models_path}")
# 移动并重命名huggingface文件夹
shutil.move(str(huggingface_path), str(models_path))
print("正在清理原始文件...")
# 删除原始actor目录及其内容
shutil.rmtree(str(actor_path))
# 删除其他可能存在的文件 (根据图片描述)
for item in root_path.glob("*"):
if item.name != "models":
if item.is_file():
item.unlink()
elif item.is_dir():
shutil.rmtree(str(item))
print("文件夹重组完成!")
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--local_dir",
        default="/mnt/lyc/wuxinrui/R1_training/training/TCM4_addthinkprunedata/step_17_reward_0.668/actor",
        type=str,
        help="The path to the saved sharded model (the actor directory)",
    )
    parser.add_argument(
        "--hf_upload_path",
        default=None,
        type=str,
        help="The Hugging Face repo id to upload the merged model to",
    )
    args = parser.parse_args()

    # Use the parsed arguments instead of ignoring them; the defaults preserve
    # the original hard-coded paths.
    merge_and_save_model(args.local_dir, args.hf_upload_path)
    # reorganize_folders expects the parent of the actor directory
    reorganize_folders(os.path.dirname(args.local_dir.rstrip("/")))