| | |
| |
|
| | import json |
| | import os |
| | from pathlib import Path |
| | from safetensors.torch import load_file, save_file, safe_open |
| | from collections import defaultdict |
| | import torch |
| | import shutil |
| | from tqdm import tqdm |
| |
|
| | |
# --- Configuration -------------------------------------------------------
# Base checkpoint: the original (multimodal) model whose shards are used as
# the template for the merged output.
BASE_MODEL_DIR = Path("/home/dgxuser/workspace/Mango/models/Mistral-Small-3.2-24B-Instruct-2506")
# Fine-tuned checkpoint whose language-model weights are grafted onto the base.
TRAINED_MODEL_DIR = Path("/home/dgxuser/workspace/Mango/axolotl/24B-Retrain/merged")
# Destination directory for the recombined model.
OUTPUT_MODEL_DIR = Path("/home/dgxuser/workspace/docshotgun/models/MS3.2-Venice-SFT-KTO-0.35-beta-re-vision")

# Tensor-name prefix the base checkpoint uses for its language-model weights;
# stripping it yields the corresponding name in the trained checkpoint.
BASE_LM_PREFIX = "language_model."

# The trained checkpoint stores tensors with no prefix.
# NOTE(review): this constant is never read below — kept for documentation.
TRAINED_LM_PREFIX = ""
| |
|
| | |
# Warn — but do not abort — when the destination already holds files;
# existing shards are simply overwritten by the merge below.
if OUTPUT_MODEL_DIR.exists():
    if any(OUTPUT_MODEL_DIR.iterdir()):
        print(f"Warning: Output directory {OUTPUT_MODEL_DIR} already exists and is not empty.")

# Ensure the destination exists before any shard is written.
OUTPUT_MODEL_DIR.mkdir(parents=True, exist_ok=True)
| |
|
| | |
def _load_index(model_dir):
    """Locate and parse a sharded-checkpoint index (*.safetensors.index.json).

    Parameters:
        model_dir: Path to a model directory containing exactly one
            ``*.safetensors.index.json`` file.

    Returns:
        (index_path, index_dict) — the path of the index file found and its
        parsed JSON contents.

    Raises:
        FileNotFoundError: if no index file exists in ``model_dir``.
    """
    try:
        # next() on the glob takes the first match; StopIteration means none.
        index_path = next(model_dir.glob("*.safetensors.index.json"))
    except StopIteration:
        raise FileNotFoundError(f"Could not find *.safetensors.index.json in {model_dir}")
    with open(index_path, 'r') as f:
        return index_path, json.load(f)


# Both checkpoints are sharded; load each index once (deduplicates the two
# identical try/glob/json blocks the script previously repeated).
base_index_path, base_index = _load_index(BASE_MODEL_DIR)
print(f"Loaded base model index from: {base_index_path}")

trained_index_path, trained_index = _load_index(TRAINED_MODEL_DIR)
print(f"Loaded trained model index from: {trained_index_path}")
| |
|
| |
|
| | |
| | |
# Lookup table: trained tensor name -> shard file that stores it.
trained_tensor_to_shard = trained_index.get("weight_map", {})
if not trained_tensor_to_shard:
    raise ValueError("Could not find 'weight_map' in trained model index.")
print(f"Built lookup map for {len(trained_tensor_to_shard)} trained tensors.")

# Same lookup for the base checkpoint.
base_weight_map = base_index.get("weight_map", {})
if not base_weight_map:
    raise ValueError("Could not find 'weight_map' in base model index.")

# Invert the base map: shard file -> list of tensor names stored in it,
# so the merge can proceed one base shard at a time.
base_shards_content = defaultdict(list)
for name, shard in base_weight_map.items():
    base_shards_content[shard].append(name)

print(f"Processing {len(base_shards_content)} shards from the base model...")
| |
|
| | |
# Merge loop: for each base shard, load it, swap in the matching trained
# language-model tensors, and re-save it under the same filename (so the
# base index's weight_map stays valid for the output).
for shard_file, tensors_in_shard in tqdm(base_shards_content.items(), desc="Merging Shards"):
    base_shard_path = BASE_MODEL_DIR / shard_file
    output_shard_path = OUTPUT_MODEL_DIR / shard_file

    # Load the whole base shard onto CPU; tensors are replaced in this dict
    # and the dict is saved back out at the end of the iteration.
    current_shard_tensors = load_file(base_shard_path, device="cpu")

    # Map base tensor name -> trained tensor name for every tensor in this
    # shard that has a counterpart in the trained checkpoint.
    tensors_to_replace = {}
    for base_tensor_name in tensors_in_shard:
        if base_tensor_name.startswith(BASE_LM_PREFIX):
            # The base nests LM weights under the prefix; the trained
            # checkpoint stores the same tensors unprefixed.
            potential_trained_name = base_tensor_name[len(BASE_LM_PREFIX):]
            if potential_trained_name in trained_tensor_to_shard:
                tensors_to_replace[base_tensor_name] = potential_trained_name
            # else: base-only tensor (e.g. vision tower) — keep base weights.
        elif base_tensor_name == "lm_head.weight":
            # lm_head sits outside the language_model. prefix in the base.
            if "lm_head.weight" in trained_tensor_to_shard:
                tensors_to_replace[base_tensor_name] = "lm_head.weight"

    # Group needed trained tensors by the trained shard that stores them so
    # each trained shard file is loaded at most once per base shard.
    needed_trained_shards = defaultdict(list)
    unresolved_base_names = []
    for base_name, trained_name in tensors_to_replace.items():
        try:
            trained_shard_file = trained_tensor_to_shard[trained_name]
            needed_trained_shards[trained_shard_file].append(trained_name)
        except KeyError:
            print(f" Warning: Tensor '{trained_name}' (derived from '{base_name}') listed for replacement but not found in trained model's weight map. Skipping.")
            unresolved_base_names.append(base_name)
    # BUGFIX: the original deleted from tensors_to_replace inside the loop
    # above, mutating the dict while iterating it (RuntimeError). Prune
    # after the iteration instead.
    for base_name in unresolved_base_names:
        del tensors_to_replace[base_name]

    # Pull the required tensors out of each trained shard.
    loaded_trained_tensors = {}
    for trained_shard_file, trained_tensor_names in needed_trained_shards.items():
        trained_shard_path = TRAINED_MODEL_DIR / trained_shard_file
        try:
            shard_data = load_file(trained_shard_path, device="cpu")
            for name in trained_tensor_names:
                if name in shard_data:
                    loaded_trained_tensors[name] = shard_data[name]
                else:
                    print(f" Warning: Expected tensor '{name}' not found within loaded trained shard '{trained_shard_file}'.")
            del shard_data  # release the shard's memory promptly
        except FileNotFoundError:
            print(f" Error: Could not find required trained shard file: {trained_shard_path}. Cannot perform replacements for tensors in this shard.")
            # Drop every planned replacement that depended on this shard.
            # (Safe: the list comprehension materializes before deletion.)
            base_names_to_remove = [b_name for b_name, t_name in tensors_to_replace.items() if t_name in trained_tensor_names]
            for b_name in base_names_to_remove:
                del tensors_to_replace[b_name]
                print(f" Skipping replacement for base tensor: {b_name}")

    # Swap in the trained weights, shape-checked against the base tensor.
    replacement_count = 0
    for base_name, trained_name in tensors_to_replace.items():
        if trained_name in loaded_trained_tensors:
            if current_shard_tensors[base_name].shape != loaded_trained_tensors[trained_name].shape:
                print(f" Warning: Shape mismatch for {base_name}! Base: {current_shard_tensors[base_name].shape}, Trained: {loaded_trained_tensors[trained_name].shape}. Skipping replacement.")
                continue
            current_shard_tensors[base_name] = loaded_trained_tensors[trained_name]
            replacement_count += 1

    output_shard_path.parent.mkdir(parents=True, exist_ok=True)
    # BUGFIX: load_file discards the safetensors header metadata, and
    # transformers rejects shards whose metadata lacks "format": "pt" —
    # restore it so the merged model loads with from_pretrained.
    save_file(current_shard_tensors, output_shard_path, metadata={"format": "pt"})

    # Free per-shard buffers before the next iteration to cap peak RSS.
    del current_shard_tensors
    del loaded_trained_tensors

print("Finished processing shards.")
| |
|
| | |
# Bring over everything that is not a tensor shard: config, tokenizer,
# processor files, and any sub-directories from the base model.
print("Copying non-tensor files (index, config, tokenizer, etc.)...")
copied_files = []
skipped_files = []

for item in BASE_MODEL_DIR.iterdir():
    destination = OUTPUT_MODEL_DIR / item.name
    if item.is_dir():
        # Replace any stale copy of the directory wholesale.
        if destination.exists():
            shutil.rmtree(destination)
        try:
            shutil.copytree(item, destination)
        except Exception as e:
            skipped_files.append(f"{item.name}/ (Error: {e})")
        else:
            copied_files.append(f"{item.name}/")
    elif item.is_file() and ".safetensors" not in item.name and ".md" not in item.name:
        try:
            shutil.copy2(item, destination)
        except Exception as e:
            skipped_files.append(f"{item.name} (Error: {e})")
        else:
            copied_files.append(item.name)

# The index filename contains ".safetensors", so the filter above skipped it;
# copy it explicitly (it stays valid because output shards keep base names).
try:
    shutil.copy2(base_index_path, OUTPUT_MODEL_DIR / base_index_path.name)
except Exception as e:
    skipped_files.append(f"{base_index_path.name} (Error: {e})")
else:
    copied_files.append(base_index_path.name)

print(f"Copied: {', '.join(copied_files)}")
if skipped_files:
    print(f"Skipped/Errors: {', '.join(skipped_files)}")

print(f"\nSuccessfully created merged model in: {OUTPUT_MODEL_DIR}")
| |
|