File size: 10,782 Bytes
8c785f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
# pip install pathlib safetensors tqdm 

import json
import os
from pathlib import Path
from safetensors.torch import load_file, save_file, safe_open
from collections import defaultdict
import torch # Needed for tensor manipulation if any dtype/device casting were required (not expected here)
import shutil
from tqdm import tqdm # Optional: for progress bar

# --- Configuration ---
BASE_MODEL_DIR = Path("/home/dgxuser/workspace/Mango/models/Mistral-Small-3.2-24B-Instruct-2506")
TRAINED_MODEL_DIR = Path("/home/dgxuser/workspace/Mango/axolotl/24B-Retrain/merged")
OUTPUT_MODEL_DIR = Path("/home/dgxuser/workspace/docshotgun/models/MS3.2-Venice-SFT-KTO-0.35-beta-re-vision")

# Define the prefix used in the base model for language model layers
BASE_LM_PREFIX = "language_model."
# Define the prefix used in the trained model for language model layers
# (Assuming the trained model has the prefix stripped)
TRAINED_LM_PREFIX = "" # If trained keys are 'model.layers...', this is effectively empty relative to the base

# --- Safety Check ---
if OUTPUT_MODEL_DIR.exists() and any(OUTPUT_MODEL_DIR.iterdir()):
    print(f"Warning: Output directory {OUTPUT_MODEL_DIR} already exists and is not empty.")
    # Decide if you want to overwrite or stop
    # input("Press Enter to continue and potentially overwrite files, or Ctrl+C to abort.")
    pass # Or raise an error: raise FileExistsError(f"Output directory {OUTPUT_MODEL_DIR} is not empty.")

# --- Create Output Directory ---
OUTPUT_MODEL_DIR.mkdir(parents=True, exist_ok=True)

# --- Load Index Files ---
try:
    base_index_path = next(BASE_MODEL_DIR.glob("*.safetensors.index.json"))
    with open(base_index_path, 'r') as f:
        base_index = json.load(f)
    print(f"Loaded base model index from: {base_index_path}")
except StopIteration:
    raise FileNotFoundError(f"Could not find *.safetensors.index.json in {BASE_MODEL_DIR}")

try:
    trained_index_path = next(TRAINED_MODEL_DIR.glob("*.safetensors.index.json"))
    with open(trained_index_path, 'r') as f:
        trained_index = json.load(f)
    print(f"Loaded trained model index from: {trained_index_path}")
except StopIteration:
    raise FileNotFoundError(f"Could not find *.safetensors.index.json in {TRAINED_MODEL_DIR}")


# --- Prepare Trained Tensor Lookup ---
# Create a map from trained tensor name to the shard file it's in
trained_tensor_to_shard = trained_index.get("weight_map", {})
if not trained_tensor_to_shard:
     raise ValueError("Could not find 'weight_map' in trained model index.")
print(f"Built lookup map for {len(trained_tensor_to_shard)} trained tensors.")

# --- Process Shards ---
base_weight_map = base_index.get("weight_map", {})
if not base_weight_map:
     raise ValueError("Could not find 'weight_map' in base model index.")

# Group base tensors by the shard they belong to
base_shards_content = defaultdict(list)
for tensor_name, shard_file in base_weight_map.items():
    base_shards_content[shard_file].append(tensor_name)

print(f"Processing {len(base_shards_content)} shards from the base model...")

# Use tqdm for progress bar over shards
for shard_file, tensors_in_shard in tqdm(base_shards_content.items(), desc="Merging Shards"):
    base_shard_path = BASE_MODEL_DIR / shard_file
    output_shard_path = OUTPUT_MODEL_DIR / shard_file

    # Load the current base model shard
    # print(f"  Loading base shard: {shard_file}")
    current_shard_tensors = load_file(base_shard_path, device="cpu") # Load to CPU to save GPU memory

    # Identify which tensors in this shard need replacement
    tensors_to_replace = {} # {base_tensor_name: trained_tensor_name}
    for base_tensor_name in tensors_in_shard:
        if base_tensor_name.startswith(BASE_LM_PREFIX):
            # Derive the corresponding name in the trained model
            # e.g., language_model.model.layers.0... -> model.layers.0...
            potential_trained_name = base_tensor_name[len(BASE_LM_PREFIX):]

            # Check if this derived name exists in the trained model's index
            if potential_trained_name in trained_tensor_to_shard:
                tensors_to_replace[base_tensor_name] = potential_trained_name
            else:
                 # This might happen for non-layer LM parts if the naming convention differs
                 # Or if the base model has LM parts not present in the stripped trained model
                 # print(f"    Debug: Base tensor {base_tensor_name} starts with prefix, but derived name {potential_trained_name} not found in trained model map. Skipping replacement.")
                 pass # Keep the base tensor

        # --- Explicit Check for LM Head (Common Case) ---
        # Many models have `lm_head.weight` outside the `language_model` block
        # Check if the trained model also has `lm_head.weight` (or similar)
        elif base_tensor_name == "lm_head.weight": # Adjust if your LM head has a different name
            if "lm_head.weight" in trained_tensor_to_shard:
                 tensors_to_replace[base_tensor_name] = "lm_head.weight"
            else:
                # print(f"    Debug: Base tensor 'lm_head.weight' found, but not present in trained model map. Skipping replacement.")
                pass # Keep the base tensor

    # Group the needed trained tensors by the shard they are located in
    needed_trained_shards = defaultdict(list) # {trained_shard_file: [list of trained_tensor_names]}
    for base_name, trained_name in tensors_to_replace.items():
        try:
            trained_shard_file = trained_tensor_to_shard[trained_name]
            needed_trained_shards[trained_shard_file].append(trained_name)
        except KeyError:
            print(f"    Warning: Tensor '{trained_name}' (derived from '{base_name}') listed for replacement but not found in trained model's weight map. Skipping.")
            # Remove from replacement list if lookup fails
            del tensors_to_replace[base_name]


    # Load needed trained shards one by one and perform replacements
    loaded_trained_tensors = {}
    for trained_shard_file, trained_tensor_names in needed_trained_shards.items():
        trained_shard_path = TRAINED_MODEL_DIR / trained_shard_file
        # print(f"    Loading trained shard: {trained_shard_file} for {len(trained_tensor_names)} tensor(s)")
        try:
            # Load only the required tensors from the trained shard if possible (optimisation - requires safetensors >= 0.4.0)
            # Note: As of mid-2023, load_file loads the whole shard. This is aspirational or requires custom loading.
            # For now, we load the whole shard.
            shard_data = load_file(trained_shard_path, device="cpu")
            for name in trained_tensor_names:
                 if name in shard_data:
                     loaded_trained_tensors[name] = shard_data[name]
                 else:
                     print(f"      Warning: Expected tensor '{name}' not found within loaded trained shard '{trained_shard_file}'.")
            del shard_data # Free memory
        except FileNotFoundError:
             print(f"    Error: Could not find required trained shard file: {trained_shard_path}. Cannot perform replacements for tensors in this shard.")
             # Remove base tensors that relied on this missing shard from replacement list
             base_names_to_remove = [b_name for b_name, t_name in tensors_to_replace.items() if t_name in trained_tensor_names]
             for b_name in base_names_to_remove:
                 del tensors_to_replace[b_name]
                 print(f"      Skipping replacement for base tensor: {b_name}")


    # Perform the replacements in the loaded base shard dictionary
    replacement_count = 0
    for base_name, trained_name in tensors_to_replace.items():
        if trained_name in loaded_trained_tensors:
            # Sanity check shapes (optional but recommended)
            if current_shard_tensors[base_name].shape != loaded_trained_tensors[trained_name].shape:
                 print(f"    Warning: Shape mismatch for {base_name}! Base: {current_shard_tensors[base_name].shape}, Trained: {loaded_trained_tensors[trained_name].shape}. Skipping replacement.")
                 continue
            current_shard_tensors[base_name] = loaded_trained_tensors[trained_name]
            replacement_count += 1
        # else: # Already handled by warnings above
        #    print(f"    Warning: Trained tensor '{trained_name}' was expected but not loaded. Skipping replacement for '{base_name}'.")

    # print(f"    Replaced {replacement_count} tensors in shard {shard_file}.")

    # Save the modified shard to the output directory
    # Ensure the directory for the shard exists if shards are nested (unlikely but possible)
    output_shard_path.parent.mkdir(parents=True, exist_ok=True)
    # print(f"  Saving modified shard to: {output_shard_path}")
    # Metadata can be copied if needed, but usually not necessary for simple weight replacement
    # Pass existing metadata from base_index if available and relevant per-tensor
    save_file(current_shard_tensors, output_shard_path)

    # Clean up loaded tensors for this shard
    del current_shard_tensors
    del loaded_trained_tensors

print("Finished processing shards.")

# --- Copy Non-Tensor Files ---
print("Copying non-tensor files (index, config, tokenizer, etc.)...")
copied_files = []
skipped_files = []

for item in BASE_MODEL_DIR.iterdir():
    # Skip the actual shard files and the index we processed
    if item.is_file() and (".safetensors" not in item.name) and (".md" not in item.name):
         output_path = OUTPUT_MODEL_DIR / item.name
         try:
            shutil.copy2(item, output_path) # copy2 preserves metadata
            copied_files.append(item.name)
         except Exception as e:
             skipped_files.append(f"{item.name} (Error: {e})")
    elif item.is_dir(): # Also copy relevant subdirectories like tokenizer configs
         output_path = OUTPUT_MODEL_DIR / item.name
         if output_path.exists():
             shutil.rmtree(output_path) # Overwrite directory if exists
         try:
             shutil.copytree(item, output_path)
             copied_files.append(f"{item.name}/")
         except Exception as e:
            skipped_files.append(f"{item.name}/ (Error: {e})")

# Specifically copy the original base index file to the new directory
try:
    shutil.copy2(base_index_path, OUTPUT_MODEL_DIR / base_index_path.name)
    copied_files.append(base_index_path.name)
except Exception as e:
    skipped_files.append(f"{base_index_path.name} (Error: {e})")


print(f"Copied: {', '.join(copied_files)}")
if skipped_files:
    print(f"Skipped/Errors: {', '.join(skipped_files)}")


print(f"\nSuccessfully created merged model in: {OUTPUT_MODEL_DIR}")