import torch
import torch.distributed.tensor  # registers DTensor so torch.load can deserialize sharded checkpoints
from safetensors.torch import save_file
import os
from collections import OrderedDict
import gc


def _to_local_tensors(state_dict, verbose=False):
    """Replace every DTensor in *state_dict* with its local shard.

    Args:
        state_dict (Mapping): checkpoint state dict; values may be DTensors,
            plain tensors, or arbitrary non-tensor objects.
        verbose (bool): print a line for each DTensor that is converted.

    Returns:
        tuple: ``(converted_dict, dtype_info)`` where ``converted_dict`` is an
        OrderedDict with DTensors replaced by their local shards, and
        ``dtype_info`` maps each key to its torch.dtype (tensors) or the
        value's type name (anything else).
    """
    converted_dict = OrderedDict()
    dtype_info = {}
    for key, value in state_dict.items():
        if hasattr(value, '_local_tensor'):
            # DTensor: keep only this rank's shard; shards from the two
            # rank files are re-joined later by concatenation.
            tensor = value._local_tensor
            converted_dict[key] = tensor
            dtype_info[key] = tensor.dtype
            if verbose:
                print(f"Converted DTensor to tensor: {key} (dtype: {tensor.dtype})")
        elif isinstance(value, torch.Tensor):
            converted_dict[key] = value
            dtype_info[key] = value.dtype
        else:
            # Non-tensor entries (e.g. counters, metadata) pass through unchanged.
            converted_dict[key] = value
            dtype_info[key] = type(value).__name__
    return converted_dict, dtype_info


def merge_fsdp_to_safetensors(rank0_path, rank1_path, output_path, target_dtype=None):
    """Merge two FSDP-sharded .pt checkpoint files into one .safetensors file.

    Keys present in both rank files are assumed to be shards of one logical
    tensor split along dim 0 (the common FSDP sharding axis) and are
    concatenated; keys present in only one rank are copied through.

    Args:
        rank0_path (str): path to the rank-0 .pt file
        rank1_path (str): path to the rank-1 .pt file
        output_path (str): destination .safetensors path
        target_dtype (torch.dtype, optional): cast all tensors to this dtype
            (e.g. torch.float16, torch.bfloat16); None keeps original dtypes

    Returns:
        OrderedDict: the merged state dict, including any non-tensor entries
        carried over from the checkpoints (those are NOT written to disk,
        since safetensors can only store tensors).
    """
    print("Loading rank 0 checkpoint...")
    # weights_only=False is required to unpickle DTensor objects.
    # NOTE(review): this executes arbitrary pickled code — only load
    # checkpoints from a trusted source.
    rank0_dict = torch.load(rank0_path, map_location='cpu', weights_only=False)
    print("Loading rank 1 checkpoint...")
    rank1_dict = torch.load(rank1_path, map_location='cpu', weights_only=False)

    print("Converting DTensors to regular tensors...")
    rank0_dict, rank0_dtypes = _to_local_tensors(rank0_dict, verbose=True)
    rank1_dict, rank1_dtypes = _to_local_tensors(rank1_dict, verbose=True)

    # Report which dtypes each shard contains.
    print("\n📋 Original dtype information:")
    all_dtypes_r0 = {d for d in rank0_dtypes.values() if isinstance(d, torch.dtype)}
    all_dtypes_r1 = {d for d in rank1_dtypes.values() if isinstance(d, torch.dtype)}
    all_dtypes = all_dtypes_r0 | all_dtypes_r1
    print(f" Rank 0 dtypes found: {all_dtypes_r0}")
    print(f" Rank 1 dtypes found: {all_dtypes_r1}")
    print(f" All dtypes: {all_dtypes}")
    if target_dtype is not None:
        print(f" Target dtype specified: {target_dtype}")
    else:
        print(" No target dtype specified - keeping original dtypes")

    merged_state_dict = OrderedDict()
    all_keys = set(rank0_dict.keys()) | set(rank1_dict.keys())
    print(f"Total unique keys found: {len(all_keys)}")

    for key in sorted(all_keys):
        rank0_tensor = rank0_dict.get(key)
        rank1_tensor = rank1_dict.get(key)

        if rank0_tensor is not None and rank1_tensor is not None:
            # Present in both ranks: treat as dim-0 shards of one tensor.
            print(f"Merging key: {key}")
            print(f" Rank 0 shape: {rank0_tensor.shape}, dtype: {rank0_tensor.dtype}")
            print(f" Rank 1 shape: {rank1_tensor.shape}, dtype: {rank1_tensor.dtype}")

            # Cast before concatenating so torch.cat sees matching dtypes.
            if target_dtype is not None and rank0_tensor.dtype != target_dtype:
                rank0_tensor = rank0_tensor.to(target_dtype)
                print(f" Converted rank 0 to {target_dtype}")
            if target_dtype is not None and rank1_tensor.dtype != target_dtype:
                rank1_tensor = rank1_tensor.to(target_dtype)
                print(f" Converted rank 1 to {target_dtype}")

            # Concatenate along the first dimension (typical FSDP sharding).
            merged_tensor = torch.cat([rank0_tensor, rank1_tensor], dim=0)
            merged_state_dict[key] = merged_tensor
            print(f" Merged shape: {merged_tensor.shape}, dtype: {merged_tensor.dtype}")
        elif rank0_tensor is not None:
            # Present only in rank 0 (replicated / unsharded entry).
            tensor = rank0_tensor
            if target_dtype is not None and isinstance(tensor, torch.Tensor) and tensor.dtype != target_dtype:
                tensor = tensor.to(target_dtype)
                print(f"Converting {key} from rank 0: {rank0_tensor.dtype} -> {target_dtype}")
            print(f"Adding from rank 0: {key} (shape: {tensor.shape if isinstance(tensor, torch.Tensor) else 'N/A'}, dtype: {tensor.dtype if isinstance(tensor, torch.Tensor) else type(tensor).__name__})")
            merged_state_dict[key] = tensor
        elif rank1_tensor is not None:
            # Present only in rank 1.
            tensor = rank1_tensor
            if target_dtype is not None and isinstance(tensor, torch.Tensor) and tensor.dtype != target_dtype:
                tensor = tensor.to(target_dtype)
                print(f"Converting {key} from rank 1: {rank1_tensor.dtype} -> {target_dtype}")
            print(f"Adding from rank 1: {key} (shape: {tensor.shape if isinstance(tensor, torch.Tensor) else 'N/A'}, dtype: {tensor.dtype if isinstance(tensor, torch.Tensor) else type(tensor).__name__})")
            merged_state_dict[key] = tensor

    print(f"\nTotal merged parameters: {len(merged_state_dict)}")

    # Release the per-rank dicts before writing the merged file.
    del rank0_dict, rank1_dict
    gc.collect()

    print(f"Saving merged model to {output_path}...")

    # Summarize the final dtype distribution.
    final_dtypes = {}
    for key, tensor in merged_state_dict.items():
        if isinstance(tensor, torch.Tensor):
            final_dtypes[tensor.dtype] = final_dtypes.get(tensor.dtype, 0) + 1
    print(f"📋 Final merged model dtype distribution:")
    for dtype, count in final_dtypes.items():
        print(f" {dtype}: {count} tensors")

    # BUGFIX: safetensors can only store tensors, and requires them to be
    # contiguous. The original code passed merged_state_dict (which may
    # contain non-tensor entries and non-contiguous DTensor shards) straight
    # to save_file, which raises in both cases.
    tensors_to_save = OrderedDict(
        (k, v.contiguous())
        for k, v in merged_state_dict.items()
        if isinstance(v, torch.Tensor)
    )
    skipped = [k for k in merged_state_dict if k not in tensors_to_save]
    if skipped:
        print(f"⚠️ Skipping {len(skipped)} non-tensor entries (not storable in safetensors): {skipped}")
    save_file(tensors_to_save, output_path)
    print("✅ Successfully saved merged model!")
    return merged_state_dict


def merge_with_custom_concatenation(rank0_path, rank1_path, output_path, concat_rules=None):
    """Merge two rank checkpoints using per-key concatenation rules.

    Args:
        rank0_path (str): path to the rank-0 .pt file
        rank1_path (str): path to the rank-1 .pt file
        output_path (str): destination .safetensors path
        concat_rules (dict, optional): maps a key substring pattern to the
            dim along which matching shared keys are concatenated; the first
            matching pattern wins, default dim is 0.
    """
    if concat_rules is None:
        # Default rules: concatenate weights and biases along dim 0.
        concat_rules = {
            'weight': 0,
            'bias': 0,
        }

    print("Loading checkpoints...")
    # NOTE(review): weights_only=False executes pickled code — trusted files only.
    rank0_dict = torch.load(rank0_path, map_location='cpu', weights_only=False)
    rank1_dict = torch.load(rank1_path, map_location='cpu', weights_only=False)

    merged_state_dict = OrderedDict()
    all_keys = set(rank0_dict.keys()) | set(rank1_dict.keys())

    for key in sorted(all_keys):
        rank0_tensor = rank0_dict.get(key)
        rank1_tensor = rank1_dict.get(key)

        if rank0_tensor is not None and rank1_tensor is not None:
            # Pick the concat dim from the first matching rule pattern.
            concat_dim = 0  # default
            for pattern, dim in concat_rules.items():
                if pattern in key:
                    concat_dim = dim
                    break
            print(f"Merging {key} along dimension {concat_dim}")
            merged_tensor = torch.cat([rank0_tensor, rank1_tensor], dim=concat_dim)
            merged_state_dict[key] = merged_tensor
        elif rank0_tensor is not None:
            merged_state_dict[key] = rank0_tensor
        elif rank1_tensor is not None:
            merged_state_dict[key] = rank1_tensor

    # Clean up and save.
    del rank0_dict, rank1_dict
    gc.collect()

    print(f"Saving to {output_path}...")
    # BUGFIX (same as merge_fsdp_to_safetensors): safetensors accepts only
    # contiguous tensors — filter and make contiguous before saving.
    tensors_to_save = OrderedDict(
        (k, v.contiguous())
        for k, v in merged_state_dict.items()
        if isinstance(v, torch.Tensor)
    )
    save_file(tensors_to_save, output_path)
    print("✅ Merge completed!")


def comprehensive_verification(rank0_path, rank1_path, merged_path):
    """Comprehensively verify that the merged file matches the two source shards.

    Checks key counts, total parameter counts, key names, and per-tensor
    shapes (shared keys are expected to have been concatenated along dim 0).

    Returns:
        bool: True when every check passes.
    """
    import torch.distributed.tensor
    from safetensors import safe_open

    print("🔍 Starting comprehensive verification...")

    # 1. Load the original shard files.
    print("\n📁 Loading original files...")
    rank0_dict = torch.load(rank0_path, map_location='cpu', weights_only=False)
    rank1_dict = torch.load(rank1_path, map_location='cpu', weights_only=False)

    # Convert DTensors to their local shards (shared helper, silent mode).
    rank0_dict, _ = _to_local_tensors(rank0_dict)
    rank1_dict, _ = _to_local_tensors(rank1_dict)

    # 2. Analyze the original files.
    # safetensors stores only tensors, so verify against tensor-valued keys
    # only (non-tensor entries are intentionally dropped at save time).
    rank0_keys = {k for k, v in rank0_dict.items() if isinstance(v, torch.Tensor)}
    rank1_keys = {k for k, v in rank1_dict.items() if isinstance(v, torch.Tensor)}
    all_original_keys = rank0_keys | rank1_keys
    shared_keys = rank0_keys & rank1_keys
    rank0_only = rank0_keys - rank1_keys
    rank1_only = rank1_keys - rank0_keys

    print(f"📊 Original files analysis:")
    print(f" Rank 0 keys: {len(rank0_keys)}")
    print(f" Rank 1 keys: {len(rank1_keys)}")
    print(f" Shared keys: {len(shared_keys)}")
    print(f" Rank 0 only: {len(rank0_only)}")
    print(f" Rank 1 only: {len(rank1_only)}")
    print(f" Total unique keys: {len(all_original_keys)}")

    # 3. Count original parameters and compute expected merged shapes.
    original_params = 0
    original_shapes = {}

    for key in all_original_keys:
        if key in shared_keys:
            # Shared key: expected size is the two shards joined along dim 0.
            r0_tensor = rank0_dict[key]
            r1_tensor = rank1_dict[key]
            combined_shape = list(r0_tensor.shape)
            combined_shape[0] += r1_tensor.shape[0]  # dim-0 concat assumed
            original_shapes[key] = tuple(combined_shape)
            original_params += r0_tensor.numel() + r1_tensor.numel()
        elif key in rank0_only:
            original_shapes[key] = rank0_dict[key].shape
            original_params += rank0_dict[key].numel()
        elif key in rank1_only:
            original_shapes[key] = rank1_dict[key].shape
            original_params += rank1_dict[key].numel()

    print(f" Original total parameters: {original_params:,}")

    # 4. Analyze the merged file.
    print(f"\n📁 Loading merged file: {merged_path}")
    merged_params = 0
    merged_keys = set()
    merged_shapes = {}

    with safe_open(merged_path, framework="pt", device="cpu") as f:
        merged_keys = set(f.keys())
        for key in f.keys():
            tensor = f.get_tensor(key)
            merged_shapes[key] = tensor.shape
            merged_params += tensor.numel()

    print(f"📊 Merged file analysis:")
    print(f" Merged keys: {len(merged_keys)}")
    print(f" Merged parameters: {merged_params:,}")

    # 5. Compare and verify.
    print(f"\n✅ Verification Results:")

    keys_match = len(merged_keys) == len(all_original_keys)
    print(f" Keys count match: {keys_match} ({len(merged_keys)} vs {len(all_original_keys)})")

    params_match = merged_params == original_params
    print(f" Parameter count match: {params_match} ({merged_params:,} vs {original_params:,})")

    missing_keys = all_original_keys - merged_keys
    extra_keys = merged_keys - all_original_keys
    if missing_keys:
        print(f" ❌ Missing keys: {missing_keys}")
    if extra_keys:
        print(f" ❌ Extra keys: {extra_keys}")

    # Per-tensor shape comparison (torch.Size compares equal to tuple).
    shape_mismatches = []
    for key in merged_keys & all_original_keys:
        if merged_shapes[key] != original_shapes[key]:
            shape_mismatches.append((key, merged_shapes[key], original_shapes[key]))

    if shape_mismatches:
        print(f" ❌ Shape mismatches:")
        for key, merged_shape, original_shape in shape_mismatches[:5]:  # show first 5 only
            print(f" {key}: {merged_shape} vs {original_shape}")
        if len(shape_mismatches) > 5:
            print(f" ... and {len(shape_mismatches) - 5} more")

    # 6. Detailed analysis (informational).
    print(f"\n📋 Detailed Analysis:")
    print(f" Shared keys that should be concatenated:")
    for key in sorted(list(shared_keys))[:10]:  # show first 10 only
        r0_shape = rank0_dict[key].shape
        r1_shape = rank1_dict[key].shape
        expected_shape = list(r0_shape)
        expected_shape[0] += r1_shape[0]
        actual_shape = merged_shapes.get(key, "MISSING")
        status = "✅" if tuple(expected_shape) == actual_shape else "❌"
        print(f" {status} {key}: {r0_shape} + {r1_shape} -> {actual_shape}")
    if len(shared_keys) > 10:
        print(f" ... and {len(shared_keys) - 10} more shared keys")

    # 7. Final verdict.
    overall_success = keys_match and params_match and not missing_keys and not extra_keys and not shape_mismatches

    print(f"\n{'='*50}")
    if overall_success:
        print("🎉 MERGE VERIFICATION SUCCESSFUL!")
        print(" All parameters have been correctly merged.")
    else:
        print("⚠️ MERGE VERIFICATION FOUND ISSUES!")
        print(" Please review the mismatches above.")
    print(f"{'='*50}")

    # Clean up.
    del rank0_dict, rank1_dict
    gc.collect()

    return overall_success


# Usage example
if __name__ == "__main__":
    # File paths — change to your actual file names.
    rank0_file = "model_rank_0.pt"
    rank1_file = "model_rank_1.pt"
    output_file = "merged_model.safetensors"

    # dtype option: convert everything to bf16.
    target_dtype = torch.bfloat16

    # Basic merge.
    print("Starting merge process...")
    merged_dict = merge_fsdp_to_safetensors(rank0_file, rank1_file, output_file, target_dtype)

    # Comprehensive verification.
    print("\nStarting comprehensive verification...")
    verification_passed = comprehensive_verification(rank0_file, rank1_file, output_file)

    if verification_passed:
        print(f"\n🎉 Successfully merged and verified FSDP model to {output_file}")
    else:
        print(f"\n⚠️ Merge completed but verification found issues. Please review the output above.")

    # Extra: simple load test on the merged file.
    print(f"\n🔍 Testing if merged model can be loaded...")
    try:
        from safetensors import safe_open
        with safe_open(output_file, framework="pt", device="cpu") as f:
            sample_keys = list(f.keys())[:3]
            for key in sample_keys:
                tensor = f.get_tensor(key)
                print(f" ✅ Successfully loaded {key}: {tensor.shape}, dtype: {tensor.dtype}")
        print(" ✅ Merged model loads correctly!")
    except Exception as e:
        print(f" ❌ Error loading merged model: {e}")

    print(f"\n💡 Tip: To change dtype, modify 'target_dtype' in the script:")
    print(f" - torch.float16 for fp16 (smaller file, less precision)")
    print(f" - torch.bfloat16 for bf16 (good balance)")
    print(f" - torch.float32 for fp32 (larger file, best precision)")
    print(f" - None to keep original dtypes")