|
|
import torch
|
|
|
import torch.distributed.tensor
|
|
|
from safetensors.torch import save_file
|
|
|
import os
|
|
|
from collections import OrderedDict
|
|
|
import gc
|
|
|
|
|
|
def merge_fsdp_to_safetensors(rank0_path, rank1_path, output_path, target_dtype=None):
    """Merge two FSDP-sharded .pt checkpoint files into one .safetensors file.

    Keys present in both ranks are concatenated along dim 0 (assumed to be
    the FSDP shard dimension -- verify against how the model was sharded);
    keys present in only one rank are copied through unchanged.

    Args:
        rank0_path (str): path to the rank 0 .pt file
        rank1_path (str): path to the rank 1 .pt file
        output_path (str): path of the .safetensors file to write
        target_dtype (torch.dtype, optional): dtype to cast every tensor to
            (e.g. torch.float16, torch.bfloat16). None keeps original dtypes.

    Returns:
        OrderedDict: the merged state dict. Non-tensor entries are kept in
        the returned dict but are NOT written to the safetensors file.
    """
    print("Loading rank 0 checkpoint...")
    # weights_only=False is required so DTensor objects can be unpickled.
    # NOTE(review): this executes pickled code -- only load trusted files.
    rank0_dict = torch.load(rank0_path, map_location='cpu', weights_only=False)

    print("Loading rank 1 checkpoint...")
    rank1_dict = torch.load(rank1_path, map_location='cpu', weights_only=False)

    def convert_dtensor_to_tensor(state_dict):
        """Unwrap DTensors to their local shard and record per-key dtypes."""
        converted_dict = OrderedDict()
        dtype_info = {}
        for key, value in state_dict.items():
            if hasattr(value, '_local_tensor'):
                # DTensor: keep only the shard held by this rank.
                tensor = value._local_tensor
                converted_dict[key] = tensor
                dtype_info[key] = tensor.dtype
                print(f"Converted DTensor to tensor: {key} (dtype: {tensor.dtype})")
            elif isinstance(value, torch.Tensor):
                converted_dict[key] = value
                dtype_info[key] = value.dtype
            else:
                # Non-tensor metadata: kept in the returned dict, but skipped
                # at save time because safetensors can only store tensors.
                converted_dict[key] = value
                dtype_info[key] = type(value).__name__
        return converted_dict, dtype_info

    print("Converting DTensors to regular tensors...")
    rank0_dict, rank0_dtypes = convert_dtensor_to_tensor(rank0_dict)
    rank1_dict, rank1_dtypes = convert_dtensor_to_tensor(rank1_dict)

    print("\n[INFO] Original dtype information:")
    all_dtypes_r0 = {d for d in rank0_dtypes.values() if isinstance(d, torch.dtype)}
    all_dtypes_r1 = {d for d in rank1_dtypes.values() if isinstance(d, torch.dtype)}
    all_dtypes = all_dtypes_r0 | all_dtypes_r1

    print(f"  Rank 0 dtypes found: {all_dtypes_r0}")
    print(f"  Rank 1 dtypes found: {all_dtypes_r1}")
    print(f"  All dtypes: {all_dtypes}")

    if target_dtype:
        print(f"  Target dtype specified: {target_dtype}")
    else:
        print("  No target dtype specified - keeping original dtypes")

    merged_state_dict = OrderedDict()

    all_keys = set(rank0_dict.keys()) | set(rank1_dict.keys())
    print(f"Total unique keys found: {len(all_keys)}")

    for key in sorted(all_keys):
        rank0_tensor = rank0_dict.get(key)
        rank1_tensor = rank1_dict.get(key)

        if rank0_tensor is not None and rank1_tensor is not None:
            print(f"Merging key: {key}")
            print(f"  Rank 0 shape: {rank0_tensor.shape}, dtype: {rank0_tensor.dtype}")
            print(f"  Rank 1 shape: {rank1_tensor.shape}, dtype: {rank1_tensor.dtype}")

            if target_dtype and rank0_tensor.dtype != target_dtype:
                rank0_tensor = rank0_tensor.to(target_dtype)
                print(f"  Converted rank 0 to {target_dtype}")
            if target_dtype and rank1_tensor.dtype != target_dtype:
                rank1_tensor = rank1_tensor.to(target_dtype)
                print(f"  Converted rank 1 to {target_dtype}")

            # Shared key: the full parameter is the dim-0 concat of both shards.
            merged_tensor = torch.cat([rank0_tensor, rank1_tensor], dim=0)
            merged_state_dict[key] = merged_tensor
            print(f"  Merged shape: {merged_tensor.shape}, dtype: {merged_tensor.dtype}")

        elif rank0_tensor is not None:
            tensor = rank0_tensor
            if target_dtype and isinstance(tensor, torch.Tensor) and tensor.dtype != target_dtype:
                tensor = tensor.to(target_dtype)
                print(f"Converting {key} from rank 0: {rank0_tensor.dtype} -> {target_dtype}")
            shape = tensor.shape if isinstance(tensor, torch.Tensor) else 'N/A'
            dtype = tensor.dtype if isinstance(tensor, torch.Tensor) else type(tensor).__name__
            print(f"Adding from rank 0: {key} (shape: {shape}, dtype: {dtype})")
            merged_state_dict[key] = tensor

        elif rank1_tensor is not None:
            tensor = rank1_tensor
            if target_dtype and isinstance(tensor, torch.Tensor) and tensor.dtype != target_dtype:
                tensor = tensor.to(target_dtype)
                print(f"Converting {key} from rank 1: {rank1_tensor.dtype} -> {target_dtype}")
            shape = tensor.shape if isinstance(tensor, torch.Tensor) else 'N/A'
            dtype = tensor.dtype if isinstance(tensor, torch.Tensor) else type(tensor).__name__
            print(f"Adding from rank 1: {key} (shape: {shape}, dtype: {dtype})")
            merged_state_dict[key] = tensor

    print(f"\nTotal merged parameters: {len(merged_state_dict)}")

    # Drop the shard dicts before writing to reduce peak memory.
    del rank0_dict, rank1_dict
    gc.collect()

    print(f"Saving merged model to {output_path}...")

    final_dtypes = {}
    for key, tensor in merged_state_dict.items():
        if isinstance(tensor, torch.Tensor):
            final_dtypes[tensor.dtype] = final_dtypes.get(tensor.dtype, 0) + 1

    print("[INFO] Final merged model dtype distribution:")
    for dtype, count in final_dtypes.items():
        print(f"  {dtype}: {count} tensors")

    # Fix: save_file() accepts only contiguous tensors and raises on anything
    # else. Skip non-tensor entries (they remain in the returned dict) and
    # force contiguity -- DTensor local shards are not guaranteed contiguous.
    tensors_to_save = {}
    for key, value in merged_state_dict.items():
        if isinstance(value, torch.Tensor):
            tensors_to_save[key] = value.contiguous()
        else:
            print(f"[WARN] Skipping non-tensor entry in saved file: {key} ({type(value).__name__})")
    save_file(tensors_to_save, output_path)
    print("[OK] Successfully saved merged model!")

    return merged_state_dict
|
|
|
|
|
|
def merge_with_custom_concatenation(rank0_path, rank1_path, output_path, concat_rules=None):
    """Merge two FSDP shard files using per-key concatenation dimensions.

    Args:
        rank0_path (str): path to the rank 0 .pt file
        rank1_path (str): path to the rank 1 .pt file
        output_path (str): path of the .safetensors file to write
        concat_rules (dict, optional): maps a key substring to the dimension
            to concatenate along, e.g. {'weight': 0}. The first matching
            pattern wins; keys with no matching pattern use dim 0.
    """
    if concat_rules is None:
        # Default: concatenate both weights and biases along dim 0.
        concat_rules = {
            'weight': 0,
            'bias': 0,
        }

    print("Loading checkpoints...")
    # weights_only=False so DTensor objects unpickle (trusted files only).
    rank0_dict = torch.load(rank0_path, map_location='cpu', weights_only=False)
    rank1_dict = torch.load(rank1_path, map_location='cpu', weights_only=False)

    def _unwrap(state_dict):
        """Replace DTensor entries with their local shard."""
        return OrderedDict(
            (k, v._local_tensor if hasattr(v, '_local_tensor') else v)
            for k, v in state_dict.items()
        )

    # Fix: unwrap DTensors here too (previously only merge_fsdp_to_safetensors
    # did this), otherwise torch.cat below fails on real FSDP shard files.
    rank0_dict = _unwrap(rank0_dict)
    rank1_dict = _unwrap(rank1_dict)

    merged_state_dict = OrderedDict()
    all_keys = set(rank0_dict.keys()) | set(rank1_dict.keys())

    for key in sorted(all_keys):
        rank0_tensor = rank0_dict.get(key)
        rank1_tensor = rank1_dict.get(key)

        if rank0_tensor is not None and rank1_tensor is not None:
            # Choose the concat dimension from the first matching rule.
            concat_dim = 0
            for pattern, dim in concat_rules.items():
                if pattern in key:
                    concat_dim = dim
                    break

            print(f"Merging {key} along dimension {concat_dim}")
            merged_state_dict[key] = torch.cat([rank0_tensor, rank1_tensor], dim=concat_dim)

        elif rank0_tensor is not None:
            merged_state_dict[key] = rank0_tensor
        elif rank1_tensor is not None:
            merged_state_dict[key] = rank1_tensor

    # Drop the shard dicts before writing to reduce peak memory.
    del rank0_dict, rank1_dict
    gc.collect()

    print(f"Saving to {output_path}...")
    # Fix: safetensors stores only contiguous tensors; skip anything else.
    tensors_to_save = {}
    for key, value in merged_state_dict.items():
        if isinstance(value, torch.Tensor):
            tensors_to_save[key] = value.contiguous()
        else:
            print(f"[WARN] Skipping non-tensor entry: {key} ({type(value).__name__})")
    save_file(tensors_to_save, output_path)
    print("[OK] Merge completed!")
|
|
|
|
|
|
def comprehensive_verification(rank0_path, rank1_path, merged_path):
    """Comprehensively verify that the merge produced the expected result.

    Compares key counts, total parameter counts, missing/extra keys, and
    checks that every shared key's merged shape equals the dim-0
    concatenation of the two shard shapes (matching the behavior of
    merge_fsdp_to_safetensors).

    Returns:
        bool: True when all checks pass.
    """
    # Local import so torch.load can unpickle DTensor objects from the shards.
    import torch.distributed.tensor
    from safetensors import safe_open

    print("[INFO] Starting comprehensive verification...")

    print("\n[INFO] Loading original files...")
    rank0_dict = torch.load(rank0_path, map_location='cpu', weights_only=False)
    rank1_dict = torch.load(rank1_path, map_location='cpu', weights_only=False)

    def convert_dtensor_to_tensor(state_dict):
        """Unwrap DTensors to their local shard; pass everything else through."""
        converted_dict = {}
        for key, value in state_dict.items():
            if hasattr(value, '_local_tensor'):
                converted_dict[key] = value._local_tensor
            else:
                converted_dict[key] = value
        return converted_dict

    rank0_dict = convert_dtensor_to_tensor(rank0_dict)
    rank1_dict = convert_dtensor_to_tensor(rank1_dict)

    rank0_keys = set(rank0_dict.keys())
    rank1_keys = set(rank1_dict.keys())
    all_original_keys = rank0_keys | rank1_keys
    shared_keys = rank0_keys & rank1_keys
    rank0_only = rank0_keys - rank1_keys
    rank1_only = rank1_keys - rank0_keys

    print("[INFO] Original files analysis:")
    print(f"  Rank 0 keys: {len(rank0_keys)}")
    print(f"  Rank 1 keys: {len(rank1_keys)}")
    print(f"  Shared keys: {len(shared_keys)}")
    print(f"  Rank 0 only: {len(rank0_only)}")
    print(f"  Rank 1 only: {len(rank1_only)}")
    print(f"  Total unique keys: {len(all_original_keys)}")

    # Expected shapes/parameter counts, assuming shared keys concatenate on
    # dim 0. Fix: non-tensor entries (which have no .shape/.numel()) are
    # skipped instead of raising AttributeError -- the merge function
    # explicitly tolerates such entries.
    original_params = 0
    original_shapes = {}

    for key in all_original_keys:
        if key in shared_keys:
            r0_tensor = rank0_dict[key]
            r1_tensor = rank1_dict[key]
            if not (isinstance(r0_tensor, torch.Tensor) and isinstance(r1_tensor, torch.Tensor)):
                continue
            combined_shape = list(r0_tensor.shape)
            combined_shape[0] += r1_tensor.shape[0]
            original_shapes[key] = tuple(combined_shape)
            original_params += r0_tensor.numel() + r1_tensor.numel()
        elif key in rank0_only:
            if isinstance(rank0_dict[key], torch.Tensor):
                original_shapes[key] = tuple(rank0_dict[key].shape)
                original_params += rank0_dict[key].numel()
        elif key in rank1_only:
            if isinstance(rank1_dict[key], torch.Tensor):
                original_shapes[key] = tuple(rank1_dict[key].shape)
                original_params += rank1_dict[key].numel()

    print(f"  Original total parameters: {original_params:,}")

    print(f"\n[INFO] Loading merged file: {merged_path}")
    merged_params = 0
    merged_keys = set()
    merged_shapes = {}

    with safe_open(merged_path, framework="pt", device="cpu") as f:
        merged_keys = set(f.keys())
        for key in merged_keys:
            tensor = f.get_tensor(key)
            merged_shapes[key] = tuple(tensor.shape)
            merged_params += tensor.numel()

    print("[INFO] Merged file analysis:")
    print(f"  Merged keys: {len(merged_keys)}")
    print(f"  Merged parameters: {merged_params:,}")

    print("\n[INFO] Verification Results:")

    keys_match = len(merged_keys) == len(all_original_keys)
    print(f"  Keys count match: {keys_match} ({len(merged_keys)} vs {len(all_original_keys)})")

    params_match = merged_params == original_params
    print(f"  Parameter count match: {params_match} ({merged_params:,} vs {original_params:,})")

    missing_keys = all_original_keys - merged_keys
    extra_keys = merged_keys - all_original_keys

    if missing_keys:
        print(f"  [ERROR] Missing keys: {missing_keys}")
    if extra_keys:
        print(f"  [ERROR] Extra keys: {extra_keys}")

    shape_mismatches = []
    for key in merged_keys & all_original_keys:
        expected = original_shapes.get(key)
        # Keys with no recorded expected shape (non-tensor originals) are
        # surfaced through the key/parameter count checks instead.
        if expected is not None and merged_shapes[key] != expected:
            shape_mismatches.append((key, merged_shapes[key], expected))

    if shape_mismatches:
        print("  [ERROR] Shape mismatches:")
        for key, merged_shape, original_shape in shape_mismatches[:5]:
            print(f"    {key}: {merged_shape} vs {original_shape}")
        if len(shape_mismatches) > 5:
            print(f"    ... and {len(shape_mismatches) - 5} more")

    print("\n[INFO] Detailed Analysis:")
    print("  Shared keys that should be concatenated:")
    for key in sorted(shared_keys)[:10]:
        r0 = rank0_dict[key]
        r1 = rank1_dict[key]
        if not (isinstance(r0, torch.Tensor) and isinstance(r1, torch.Tensor)):
            continue
        expected_shape = list(r0.shape)
        expected_shape[0] += r1.shape[0]
        actual_shape = merged_shapes.get(key, "MISSING")
        status = "[OK]" if tuple(expected_shape) == actual_shape else "[X]"
        print(f"    {status} {key}: {tuple(r0.shape)} + {tuple(r1.shape)} -> {actual_shape}")

    if len(shared_keys) > 10:
        print(f"  ... and {len(shared_keys) - 10} more shared keys")

    overall_success = (keys_match and params_match
                       and not missing_keys and not extra_keys and not shape_mismatches)

    print(f"\n{'='*50}")
    if overall_success:
        print("[OK] MERGE VERIFICATION SUCCESSFUL!")
        print("  All parameters have been correctly merged.")
    else:
        print("[WARN] MERGE VERIFICATION FOUND ISSUES!")
        print("  Please review the mismatches above.")
    print(f"{'='*50}")

    # Release the shard dicts before returning.
    del rank0_dict, rank1_dict
    gc.collect()

    return overall_success
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Input shard files and the merged output path.
    rank0_file = "model_rank_0.pt"
    rank1_file = "model_rank_1.pt"
    output_file = "merged_model.safetensors"

    # Target dtype for the merged model (set to None to keep original dtypes).
    target_dtype = torch.bfloat16

    # Fail early with a clear message if the shard files are missing
    # (torch.load would otherwise raise deep inside the merge).
    missing = [p for p in (rank0_file, rank1_file) if not os.path.exists(p)]
    if missing:
        raise FileNotFoundError(f"Missing checkpoint file(s): {missing}")

    print("Starting merge process...")
    merged_dict = merge_fsdp_to_safetensors(rank0_file, rank1_file, output_file, target_dtype)

    print("\nStarting comprehensive verification...")
    verification_passed = comprehensive_verification(rank0_file, rank1_file, output_file)

    if verification_passed:
        print(f"\n[OK] Successfully merged and verified FSDP model to {output_file}")
    else:
        print("\n[WARN] Merge completed but verification found issues. Please review the output above.")

    # Smoke test: re-open the merged file and read a few tensors back.
    print("\n[INFO] Testing if merged model can be loaded...")
    try:
        from safetensors import safe_open
        with safe_open(output_file, framework="pt", device="cpu") as f:
            sample_keys = list(f.keys())[:3]
            for key in sample_keys:
                tensor = f.get_tensor(key)
                print(f"  [OK] Successfully loaded {key}: {tensor.shape}, dtype: {tensor.dtype}")
        print("  [OK] Merged model loads correctly!")
    except Exception as e:
        print(f"  [ERROR] Error loading merged model: {e}")

    print("\n[TIP] To change dtype, modify 'target_dtype' in the script:")
    print("  - torch.float16 for fp16 (smaller file, less precision)")
    print("  - torch.bfloat16 for bf16 (good balance)")
    print("  - torch.float32 for fp32 (larger file, best precision)")
    print("  - None to keep original dtypes")