# merge.py — merge FSDP-sharded checkpoints into a single .safetensors file
# Source: Hugging Face upload by b-re-w (commit 3e04ca0, verified)
import torch
import torch.distributed.tensor
from safetensors.torch import save_file
import os
from collections import OrderedDict
import gc
def merge_fsdp_to_safetensors(rank0_path, rank1_path, output_path, target_dtype=None):
    """
    Merge two FSDP-sharded .pt checkpoint files into a single .safetensors file.

    Keys present in both ranks are concatenated along dim 0 (the usual FSDP
    sharding axis — TODO confirm this matches how the checkpoint was sharded);
    keys present in only one rank are copied through unchanged.

    Args:
        rank0_path (str): Path to the rank 0 .pt file.
        rank1_path (str): Path to the rank 1 .pt file.
        output_path (str): Path of the .safetensors file to write.
        target_dtype (torch.dtype, optional): If given, cast every tensor to
            this dtype (e.g. torch.float16, torch.bfloat16) before merging.

    Returns:
        OrderedDict: The merged state dict. Non-tensor entries are kept in the
        returned dict but are skipped when writing the safetensors file.
    """
    print("Loading rank 0 checkpoint...")
    rank0_dict = torch.load(rank0_path, map_location='cpu', weights_only=False)
    print("Loading rank 1 checkpoint...")
    rank1_dict = torch.load(rank1_path, map_location='cpu', weights_only=False)

    def convert_dtensor_to_tensor(state_dict):
        """Replace each DTensor with its local shard and record dtypes per key."""
        converted_dict = OrderedDict()
        dtype_info = {}
        for key, value in state_dict.items():
            if hasattr(value, '_local_tensor'):
                # DTensor: extract the shard stored in this rank's file.
                tensor = value._local_tensor
                converted_dict[key] = tensor
                dtype_info[key] = tensor.dtype
                print(f"Converted DTensor to tensor: {key} (dtype: {tensor.dtype})")
            elif isinstance(value, torch.Tensor):
                converted_dict[key] = value
                dtype_info[key] = value.dtype
            else:
                # Keep non-tensor entries (e.g. step counters) unchanged.
                converted_dict[key] = value
                dtype_info[key] = type(value).__name__
        return converted_dict, dtype_info

    print("Converting DTensors to regular tensors...")
    rank0_dict, rank0_dtypes = convert_dtensor_to_tensor(rank0_dict)
    rank1_dict, rank1_dtypes = convert_dtensor_to_tensor(rank1_dict)

    # Report the dtypes found in each shard.
    print("\n📋 Original dtype information:")
    all_dtypes_r0 = {d for d in rank0_dtypes.values() if isinstance(d, torch.dtype)}
    all_dtypes_r1 = {d for d in rank1_dtypes.values() if isinstance(d, torch.dtype)}
    all_dtypes = all_dtypes_r0 | all_dtypes_r1
    print(f" Rank 0 dtypes found: {all_dtypes_r0}")
    print(f" Rank 1 dtypes found: {all_dtypes_r1}")
    print(f" All dtypes: {all_dtypes}")
    if target_dtype is not None:
        print(f" Target dtype specified: {target_dtype}")
    else:
        print(" No target dtype specified - keeping original dtypes")

    def maybe_cast(tensor):
        """Cast a tensor to target_dtype when one was requested."""
        if target_dtype is not None and isinstance(tensor, torch.Tensor) and tensor.dtype != target_dtype:
            return tensor.to(target_dtype)
        return tensor

    merged_state_dict = OrderedDict()
    all_keys = set(rank0_dict.keys()) | set(rank1_dict.keys())
    print(f"Total unique keys found: {len(all_keys)}")
    for key in sorted(all_keys):
        rank0_tensor = rank0_dict.get(key)
        rank1_tensor = rank1_dict.get(key)
        if rank0_tensor is not None and rank1_tensor is not None:
            # Present in both ranks: cast if requested, then concatenate
            # along dim 0.
            print(f"Merging key: {key}")
            print(f" Rank 0 shape: {rank0_tensor.shape}, dtype: {rank0_tensor.dtype}")
            print(f" Rank 1 shape: {rank1_tensor.shape}, dtype: {rank1_tensor.dtype}")
            rank0_tensor = maybe_cast(rank0_tensor)
            rank1_tensor = maybe_cast(rank1_tensor)
            merged_tensor = torch.cat([rank0_tensor, rank1_tensor], dim=0)
            merged_state_dict[key] = merged_tensor
            print(f" Merged shape: {merged_tensor.shape}, dtype: {merged_tensor.dtype}")
        elif rank0_tensor is not None:
            # Present only on rank 0.
            tensor = maybe_cast(rank0_tensor)
            print(f"Adding from rank 0: {key} (shape: {tensor.shape if isinstance(tensor, torch.Tensor) else 'N/A'}, dtype: {tensor.dtype if isinstance(tensor, torch.Tensor) else type(tensor).__name__})")
            merged_state_dict[key] = tensor
        elif rank1_tensor is not None:
            # Present only on rank 1.
            tensor = maybe_cast(rank1_tensor)
            print(f"Adding from rank 1: {key} (shape: {tensor.shape if isinstance(tensor, torch.Tensor) else 'N/A'}, dtype: {tensor.dtype if isinstance(tensor, torch.Tensor) else type(tensor).__name__})")
            merged_state_dict[key] = tensor

    print(f"\nTotal merged parameters: {len(merged_state_dict)}")
    # Free the per-rank shard dicts before serializing the merged model.
    del rank0_dict, rank1_dict
    gc.collect()

    print(f"Saving merged model to {output_path}...")
    # Final dtype distribution of the merged tensors.
    final_dtypes = {}
    for key, tensor in merged_state_dict.items():
        if isinstance(tensor, torch.Tensor):
            final_dtypes[tensor.dtype] = final_dtypes.get(tensor.dtype, 0) + 1
    print(f"📋 Final merged model dtype distribution:")
    for dtype, count in final_dtypes.items():
        print(f" {dtype}: {count} tensors")

    # Bug fix: safetensors.save_file raises on non-tensor values and on
    # non-contiguous tensors. Skip non-tensor entries (with a warning) and
    # force contiguity on the rest instead of crashing here.
    save_dict = {}
    for key, tensor in merged_state_dict.items():
        if isinstance(tensor, torch.Tensor):
            save_dict[key] = tensor.contiguous()
        else:
            print(f"⚠️ Skipping non-tensor entry for safetensors: {key} ({type(tensor).__name__})")
    save_file(save_dict, output_path)
    print("✅ Successfully saved merged model!")
    return merged_state_dict
def merge_with_custom_concatenation(rank0_path, rank1_path, output_path, concat_rules=None):
    """
    Merge two FSDP-sharded .pt checkpoints using per-key concatenation rules.

    Args:
        rank0_path (str): Path to the rank 0 .pt file.
        rank1_path (str): Path to the rank 1 .pt file.
        output_path (str): Path of the .safetensors file to write.
        concat_rules (dict, optional): Mapping of key substring to concat
            dimension, e.g. {'weight': 0}. The first pattern contained in a
            key wins; keys matching no pattern are concatenated along dim 0.

    Returns:
        OrderedDict: The merged state dict.
    """
    if concat_rules is None:
        # Default rules: concatenate weights and biases along dim 0.
        concat_rules = {
            'weight': 0,
            'bias': 0,
        }
    print("Loading checkpoints...")
    rank0_dict = torch.load(rank0_path, map_location='cpu', weights_only=False)
    rank1_dict = torch.load(rank1_path, map_location='cpu', weights_only=False)

    def to_local(value):
        # Consistency fix: like merge_fsdp_to_safetensors, unwrap DTensors to
        # their local shard so torch.cat and save_file work on FSDP output.
        return value._local_tensor if hasattr(value, '_local_tensor') else value

    merged_state_dict = OrderedDict()
    all_keys = set(rank0_dict.keys()) | set(rank1_dict.keys())
    for key in sorted(all_keys):
        rank0_tensor = to_local(rank0_dict.get(key))
        rank1_tensor = to_local(rank1_dict.get(key))
        if rank0_tensor is not None and rank1_tensor is not None:
            # Pick the concat dimension from the first matching rule.
            concat_dim = 0  # default
            for pattern, dim in concat_rules.items():
                if pattern in key:
                    concat_dim = dim
                    break
            print(f"Merging {key} along dimension {concat_dim}")
            merged_state_dict[key] = torch.cat([rank0_tensor, rank1_tensor], dim=concat_dim)
        elif rank0_tensor is not None:
            merged_state_dict[key] = rank0_tensor
        elif rank1_tensor is not None:
            merged_state_dict[key] = rank1_tensor

    # Clean up and save.
    del rank0_dict, rank1_dict
    gc.collect()
    print(f"Saving to {output_path}...")
    # Bug fix: safetensors.save_file rejects non-tensor values and
    # non-contiguous tensors, so sanitize before saving.
    save_dict = {}
    for key, value in merged_state_dict.items():
        if isinstance(value, torch.Tensor):
            save_dict[key] = value.contiguous()
        else:
            print(f"⚠️ Skipping non-tensor entry for safetensors: {key} ({type(value).__name__})")
    save_file(save_dict, output_path)
    print("✅ Merge completed!")
    return merged_state_dict
def comprehensive_verification(rank0_path, rank1_path, merged_path):
    """Comprehensively verify that a merge produced a correct checkpoint.

    Reloads both per-rank .pt files and the merged .safetensors file, then
    compares key sets, total parameter counts, and per-key shapes, assuming
    shared keys were concatenated along dim 0.

    Args:
        rank0_path: Path to the rank 0 .pt file.
        rank1_path: Path to the rank 1 .pt file.
        merged_path: Path to the merged .safetensors file.

    Returns:
        bool: True when every check passes, False otherwise.
    """
    # Imported here so DTensor checkpoints can be unpickled and the merged
    # file can be read lazily.
    import torch.distributed.tensor
    from safetensors import safe_open
    print("🔍 Starting comprehensive verification...")
    # 1. Load the original per-rank files
    print("\n📁 Loading original files...")
    rank0_dict = torch.load(rank0_path, map_location='cpu', weights_only=False)
    rank1_dict = torch.load(rank1_path, map_location='cpu', weights_only=False)
    # Convert DTensors to regular tensors (extract the local shard)
    def convert_dtensor_to_tensor(state_dict):
        # NOTE(review): non-tensor values pass through unchanged; the shape
        # checks below assume every surviving value is a tensor — confirm.
        converted_dict = {}
        for key, value in state_dict.items():
            if hasattr(value, '_local_tensor'):
                converted_dict[key] = value._local_tensor
            elif isinstance(value, torch.Tensor):
                converted_dict[key] = value
            else:
                converted_dict[key] = value
        return converted_dict
    rank0_dict = convert_dtensor_to_tensor(rank0_dict)
    rank1_dict = convert_dtensor_to_tensor(rank1_dict)
    # 2. Analyze the original files
    rank0_keys = set(rank0_dict.keys())
    rank1_keys = set(rank1_dict.keys())
    all_original_keys = rank0_keys | rank1_keys
    shared_keys = rank0_keys & rank1_keys
    rank0_only = rank0_keys - rank1_keys
    rank1_only = rank1_keys - rank0_keys
    print(f"📊 Original files analysis:")
    print(f" Rank 0 keys: {len(rank0_keys)}")
    print(f" Rank 1 keys: {len(rank1_keys)}")
    print(f" Shared keys: {len(shared_keys)}")
    print(f" Rank 0 only: {len(rank0_only)}")
    print(f" Rank 1 only: {len(rank1_only)}")
    print(f" Total unique keys: {len(all_original_keys)}")
    # 3. Count the original parameters
    original_params = 0
    original_shapes = {}
    for key in all_original_keys:
        if key in shared_keys:
            # Shared keys count as the size of the two shards concatenated
            r0_tensor = rank0_dict[key]
            r1_tensor = rank1_dict[key]
            combined_shape = list(r0_tensor.shape)
            combined_shape[0] += r1_tensor.shape[0]  # assumes concat along dim 0
            original_shapes[key] = tuple(combined_shape)
            original_params += r0_tensor.numel() + r1_tensor.numel()
        elif key in rank0_only:
            original_shapes[key] = rank0_dict[key].shape
            original_params += rank0_dict[key].numel()
        elif key in rank1_only:
            original_shapes[key] = rank1_dict[key].shape
            original_params += rank1_dict[key].numel()
    print(f" Original total parameters: {original_params:,}")
    # 4. Analyze the merged file
    print(f"\n📁 Loading merged file: {merged_path}")
    merged_params = 0
    merged_keys = set()
    merged_shapes = {}
    with safe_open(merged_path, framework="pt", device="cpu") as f:
        merged_keys = set(f.keys())
        for key in f.keys():
            tensor = f.get_tensor(key)
            merged_shapes[key] = tensor.shape
            merged_params += tensor.numel()
    print(f"📊 Merged file analysis:")
    print(f" Merged keys: {len(merged_keys)}")
    print(f" Merged parameters: {merged_params:,}")
    # 5. Compare and verify
    print(f"\n✅ Verification Results:")
    # Compare key counts
    keys_match = len(merged_keys) == len(all_original_keys)
    print(f" Keys count match: {keys_match} ({len(merged_keys)} vs {len(all_original_keys)})")
    # Compare parameter counts
    params_match = merged_params == original_params
    print(f" Parameter count match: {params_match} ({merged_params:,} vs {original_params:,})")
    # Compare key names
    missing_keys = all_original_keys - merged_keys
    extra_keys = merged_keys - all_original_keys
    if missing_keys:
        print(f" ❌ Missing keys: {missing_keys}")
    if extra_keys:
        print(f" ❌ Extra keys: {extra_keys}")
    # Compare individual tensor shapes
    shape_mismatches = []
    for key in merged_keys & all_original_keys:
        if merged_shapes[key] != original_shapes[key]:
            shape_mismatches.append((key, merged_shapes[key], original_shapes[key]))
    if shape_mismatches:
        print(f" ❌ Shape mismatches:")
        for key, merged_shape, original_shape in shape_mismatches[:5]:  # show first 5 only
            print(f" {key}: {merged_shape} vs {original_shape}")
        if len(shape_mismatches) > 5:
            print(f" ... and {len(shape_mismatches) - 5} more")
    # 6. Detailed analysis (optional)
    print(f"\n📋 Detailed Analysis:")
    print(f" Shared keys that should be concatenated:")
    for key in sorted(list(shared_keys))[:10]:  # show first 10 only
        r0_shape = rank0_dict[key].shape
        r1_shape = rank1_dict[key].shape
        expected_shape = list(r0_shape)
        expected_shape[0] += r1_shape[0]
        actual_shape = merged_shapes.get(key, "MISSING")
        status = "✅" if tuple(expected_shape) == actual_shape else "❌"
        print(f" {status} {key}: {r0_shape} + {r1_shape} -> {actual_shape}")
    if len(shared_keys) > 10:
        print(f" ... and {len(shared_keys) - 10} more shared keys")
    # 7. Final result
    overall_success = keys_match and params_match and not missing_keys and not extra_keys and not shape_mismatches
    print(f"\n{'='*50}")
    if overall_success:
        print("🎉 MERGE VERIFICATION SUCCESSFUL!")
        print(" All parameters have been correctly merged.")
    else:
        print("⚠️ MERGE VERIFICATION FOUND ISSUES!")
        print(" Please review the mismatches above.")
    print(f"{'='*50}")
    # Cleanup
    del rank0_dict, rank1_dict
    gc.collect()
    return overall_success
# Example usage
if __name__ == "__main__":
    # Checkpoint shard paths (replace with your actual filenames).
    rank0_file = "model_rank_0.pt"
    rank1_file = "model_rank_1.pt"
    output_file = "merged_model.safetensors"

    # dtype option: cast everything to bf16 (None keeps original dtypes).
    target_dtype = torch.bfloat16

    # Run the basic merge.
    print("Starting merge process...")
    merged_dict = merge_fsdp_to_safetensors(rank0_file, rank1_file, output_file, target_dtype)

    # Run the full verification pass over the result.
    print("\nStarting comprehensive verification...")
    verification_passed = comprehensive_verification(rank0_file, rank1_file, output_file)
    summary = (
        f"\n🎉 Successfully merged and verified FSDP model to {output_file}"
        if verification_passed
        else f"\n⚠️ Merge completed but verification found issues. Please review the output above."
    )
    print(summary)

    # Extra: quick smoke test that the merged file can be read back.
    print(f"\n🔍 Testing if merged model can be loaded...")
    try:
        from safetensors import safe_open
        with safe_open(output_file, framework="pt", device="cpu") as reader:
            for sample_key in list(reader.keys())[:3]:
                sample = reader.get_tensor(sample_key)
                print(f" ✅ Successfully loaded {sample_key}: {sample.shape}, dtype: {sample.dtype}")
        print(" ✅ Merged model loads correctly!")
    except Exception as load_error:
        print(f" ❌ Error loading merged model: {load_error}")

    # Closing hints for choosing a target dtype.
    print(f"\n💡 Tip: To change dtype, modify 'target_dtype' in the script:")
    print(f" - torch.float16 for fp16 (smaller file, less precision)")
    print(f" - torch.bfloat16 for bf16 (good balance)")
    print(f" - torch.float32 for fp32 (larger file, best precision)")
    print(f" - None to keep original dtypes")