AlekseyCalvin commited on
Commit
e859d40
·
verified ·
1 Parent(s): f4dc6b6

Create merge_utils.py

Browse files
Files changed (1) hide show
  1. merge_utils.py +237 -0
merge_utils.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import yaml
import gc
import torch
import shutil
import sys
from pathlib import Path

# --- CRITICAL PATCH: MUST RUN BEFORE MERGEKIT IMPORTS ---
import pydantic
from pydantic import ConfigDict, BaseModel
# This forces Pydantic v2 to accept torch.Tensor as a valid type globally.
# NOTE(review): this mutates BaseModel for EVERY pydantic model in the
# process, not just mergekit's — confirm no other library depends on the
# default (arbitrary_types_allowed=False) behavior.
BaseModel.model_config = ConfigDict(arbitrary_types_allowed=True)

try:
    # Standard Merging (linear / slerp / ties / dare ...)
    from mergekit.config import MergeConfiguration
    from mergekit.merge import run_merge, MergeOptions

    # MoE Merging (builds a Mixture-of-Experts model from expert checkpoints)
    from mergekit.moe.config import MoEMergeConfig
    from mergekit.scripts.moe import build as build_moe

    # Raw PyTorch Merging (flat state-dict merge for non-transformers models)
    from mergekit.scripts.merge_raw_pytorch import RawPyTorchMergeConfig, plan_flat_merge
    from mergekit.graph import Executor

except ImportError:
    # Import is best-effort so the module can still be imported (e.g. for the
    # pure-dict config builders below) when mergekit is absent; the execute_*
    # functions will then fail with NameError at call time.
    print("Warning: mergekit not installed. Please install it via requirements.txt")
def execute_mergekit_config(config_dict, out_path, shard_gb, device="cpu"):
    """Execute a MergeKit run, auto-detecting the config type.

    A config containing an ``"experts"`` key is treated as a Mixture-of-Experts
    build; anything else is run as a standard merge (linear/slerp/ties/...).

    Args:
        config_dict: Parsed merge configuration as a plain dict.
        out_path: Directory the merged model is written to.
        shard_gb: Maximum output shard size in gigabytes.
        device: Torch device string used for the merge math (default "cpu").

    Raises:
        ValueError: if a standard config fails schema validation.
        RuntimeError: if the underlying merge/build itself fails (MoE
            validation errors are also wrapped as RuntimeError).
    """
    # Free as much memory as possible before model shards are loaded.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Options shared by both branches.
    merge_opts = MergeOptions(
        device=device,
        copy_tokenizer=True,
        lazy_unpickle=True,
        low_cpu_memory=True,
        max_shard_size=int(shard_gb * 1024**3),
        allow_crimes=True,  # allow loose architecture constraints
    )

    # --- BRANCH 1: MIXTURE OF EXPERTS (MoE) ---
    if "experts" in config_dict:
        print("🚀 Detected MoE Configuration.")
        try:
            # Validate using the specific MoE schema.
            conf = MoEMergeConfig.model_validate(config_dict)

            # Execute using the build function from mergekit.scripts.moe.
            build_moe(
                config=conf,
                out_path=out_path,
                merge_options=merge_opts,
                load_in_4bit=False,
                load_in_8bit=False,
                device=device,
                verbose=True,
            )
            print("✅ MoE Construction Complete.")

        except Exception as e:
            raise RuntimeError(f"MoE Build Failed: {e}") from e

    # --- BRANCH 2: STANDARD MERGE (TIES, SLERP, ETC.) ---
    else:
        print("⚡ Detected Standard Merge Configuration.")
        try:
            # Validate using the standard schema.
            conf = MergeConfiguration.model_validate(config_dict)

            # BUGFIX: run_merge takes a single MergeOptions via ``options=``;
            # the previous code passed the option fields as individual kwargs
            # (and left merge_opts unused), which run_merge does not accept.
            run_merge(
                conf,
                out_path=out_path,
                options=merge_opts,
            )
            print("✅ Standard Merge Complete.")

        except pydantic.ValidationError as e:
            raise ValueError(f"Invalid Merge Configuration: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Merge Failed: {e}") from e

    gc.collect()
def execute_raw_pytorch(config_dict, out_path, shard_gb, device="cpu"):
    """Execute a raw (flat state-dict) PyTorch merge for non-transformers models.

    Args:
        config_dict: Parsed raw-merge configuration as a plain dict.
        out_path: Directory the merged checkpoint is written to.
        shard_gb: Maximum output shard size in gigabytes.
        device: Torch device string used for the merge math (default "cpu").

    Raises:
        RuntimeError: if validation, planning, or execution fails.
    """
    print("🧠 Executing Raw PyTorch Merge...")
    try:
        # Validate using the raw-merge schema.
        conf = RawPyTorchMergeConfig.model_validate(config_dict)

        merge_opts = MergeOptions(
            device=device,
            low_cpu_memory=True,
            # NOTE(review): this uses ``out_shard_size`` while the standard
            # path uses ``max_shard_size`` — the field name changed between
            # mergekit versions; confirm which one the pinned version expects.
            out_shard_size=int(shard_gb * 1024**3),
            lazy_unpickle=True,
            safe_serialization=True,
        )

        # Plan the merge as a graph of tasks.
        tasks = plan_flat_merge(
            conf,
            out_path,
            tensor_union=False,
            tensor_intersection=False,
            options=merge_opts,
        )

        # Execute the task graph.
        executor = Executor(
            tasks,
            math_device=device,
            storage_device="cpu",  # force storage to CPU for low-resource safety
        )
        executor.execute()
        print("✅ Raw PyTorch Merge Complete.")

    except Exception as e:
        raise RuntimeError(f"Raw Merge Failed: {e}") from e
    finally:
        gc.collect()
def build_full_merge_config(
    method, models, base_model, weights, density,
    dtype, tokenizer_source, layer_ranges
):
    """Construct the mergekit config dict for a general merge.

    Args:
        method: Merge method name (e.g. "linear", "slerp", "ties", "dare_ties");
            lower-cased before use.
        models: List of model identifiers to merge.
        base_model: Base model identifier; falls back to ``models[0]`` if falsy.
        weights: Comma-separated per-model weights (e.g. "0.5, 0.5"); may be
            None/empty. Malformed values fall back to a default weight of 1.0.
        density: Density parameter injected for TIES/DARE-style methods.
        dtype: Output dtype string.
        tokenizer_source: Tokenizer source setting.
        layer_ranges: Optional YAML/JSON string whose top-level mapping keys
            (e.g. "slices") are merged into the config; parse errors are
            reported and ignored.

    Returns:
        dict ready to be serialized to YAML / validated by mergekit.
    """
    method_key = method.lower()  # hoisted: used for config + per-model branching
    config = {
        "merge_method": method_key,
        "base_model": base_model if base_model else models[0],
        "dtype": dtype,
        "tokenizer_source": tokenizer_source,
        "models": [],
    }

    w_list = []
    if weights:
        try:
            w_list = [float(x.strip()) for x in weights.split(',')]
        except ValueError:
            # BUGFIX: was a bare ``except`` that hid every error; only a
            # malformed number should fall back to the 1.0 default weight.
            pass

    for i, m in enumerate(models):
        entry = {"model": m, "parameters": {}}

        # Method-specific parameter injection.
        if method_key in ("ties", "dare_ties", "dare_linear"):
            entry["parameters"]["weight"] = w_list[i] if i < len(w_list) else 1.0
            entry["parameters"]["density"] = density
        elif method_key in ("slerp", "linear"):
            entry["parameters"]["weight"] = w_list[i] if i < len(w_list) else 1.0

        config["models"].append(entry)

    # Inject slices / layer ranges if provided.
    if layer_ranges and layer_ranges.strip():
        try:
            extra_params = yaml.safe_load(layer_ranges)
            if isinstance(extra_params, dict):
                config.update(extra_params)
        except Exception as e:
            print(f"Error parsing layer ranges JSON: {e}")

    return config
def build_moe_config(
    base_model, experts, prompts, gate_mode, dtype,
    tokenizer_source
):
    """Build the config dict for a Mixture-of-Experts merge.

    Each expert is paired with its matching positive prompt when one is
    available and non-blank ("positive_prompts" is required by the "hidden"
    gate mode). Experts without a usable prompt are added bare; whether that
    is acceptable is left to the downstream schema validator.

    Returns:
        dict with base_model, gate_mode, dtype, tokenizer_source and the
        assembled "experts" list.
    """
    expert_entries = []
    for idx, source in enumerate(experts):
        entry = {"source_model": source}
        # Attach the prompt only when it exists and is non-blank after stripping.
        if idx < len(prompts) and prompts[idx].strip():
            entry["positive_prompts"] = [prompts[idx].strip()]
        expert_entries.append(entry)

    return {
        "base_model": base_model,
        "gate_mode": gate_mode,
        "dtype": dtype,
        "tokenizer_source": tokenizer_source,
        "experts": expert_entries,
    }
def build_raw_config(method, models, base_model, dtype, weights):
    """Construct the config dict for raw (non-transformers) PyTorch merging.

    Args:
        method: Merge method name; lower-cased before use.
        models: List of model/checkpoint identifiers.
        base_model: Optional base model; the key is omitted entirely if falsy.
        dtype: Output dtype string.
        weights: Comma-separated per-model weights; may be None/empty.
            Malformed values fall back to a default weight of 1.0.

    Returns:
        dict ready to be serialized to YAML / validated by mergekit.
    """
    config = {
        "merge_method": method.lower(),
        "dtype": dtype,
        "models": [],
    }

    if base_model:
        config["base_model"] = base_model

    w_list = []
    if weights:
        try:
            w_list = [float(x.strip()) for x in weights.split(',')]
        except ValueError:
            # BUGFIX: was a bare ``except`` that hid every error; only a
            # malformed number should fall back to the 1.0 default weight.
            pass

    for i, m in enumerate(models):
        # Most raw merge methods only take a per-model weight.
        config["models"].append(
            {"model": m, "parameters": {"weight": w_list[i] if i < len(w_list) else 1.0}}
        )

    return config