AlekseyCalvin committed on
Commit
0032edf
·
verified ·
1 Parent(s): 459f6e8

Update merge_utils.py

Browse files
Files changed (1) hide show
  1. merge_utils.py +38 -33
merge_utils.py CHANGED
@@ -4,12 +4,15 @@ import gc
4
  import torch
5
  import shutil
6
  import sys
 
7
  from pathlib import Path
8
 
 
 
 
9
  # --- CRITICAL PATCH: MUST RUN BEFORE MERGEKIT IMPORTS ---
10
  import pydantic
11
  from pydantic import ConfigDict, BaseModel
12
- # This forces Pydantic v2 to accept torch.Tensor as a valid type globally
13
  BaseModel.model_config = ConfigDict(arbitrary_types_allowed=True)
14
 
15
  try:
@@ -30,11 +33,12 @@ except ImportError:
30
 
31
  def execute_mergekit_config(config_dict, out_path, shard_gb, device="cpu"):
32
  """
33
- Executes a MergeKit run by intelligently detecting the config type.
34
  """
35
- # Force garbage collection before start
36
  gc.collect()
37
- torch.cuda.empty_cache() if torch.cuda.is_available() else None
 
38
 
39
  # Shared Options
40
  merge_opts = MergeOptions(
@@ -43,7 +47,7 @@ def execute_mergekit_config(config_dict, out_path, shard_gb, device="cpu"):
43
  lazy_unpickle=True,
44
  low_cpu_memory=True,
45
  max_shard_size=int(shard_gb * 1024**3),
46
- allow_crimes=True # Allow loose constraints
47
  )
48
 
49
  # --- BRANCH 1: MIXTURE OF EXPERTS (MoE) ---
@@ -68,14 +72,13 @@ def execute_mergekit_config(config_dict, out_path, shard_gb, device="cpu"):
68
  except Exception as e:
69
  raise RuntimeError(f"MoE Build Failed: {e}")
70
 
71
- # --- BRANCH 2: STANDARD MERGE (TIES, SLERP, ETC.) ---
72
  else:
73
  print("⚡ Detected Standard Merge Configuration.")
74
  try:
75
  # Validate using the Standard Schema
76
  conf = MergeConfiguration.model_validate(config_dict)
77
 
78
- # Execute using the standard runner
79
  run_merge(
80
  conf,
81
  out_path=out_path,
@@ -100,7 +103,6 @@ def execute_raw_pytorch(config_dict, out_path, shard_gb, device="cpu"):
100
  """
101
  print("🧠 Executing Raw PyTorch Merge...")
102
  try:
103
- # Validate using Raw Schema
104
  conf = RawPyTorchMergeConfig.model_validate(config_dict)
105
 
106
  merge_opts = MergeOptions(
@@ -111,7 +113,6 @@ def execute_raw_pytorch(config_dict, out_path, shard_gb, device="cpu"):
111
  safe_serialization=True
112
  )
113
 
114
- # Plan the merge tasks
115
  tasks = plan_flat_merge(
116
  conf,
117
  out_path,
@@ -120,11 +121,10 @@ def execute_raw_pytorch(config_dict, out_path, shard_gb, device="cpu"):
120
  options=merge_opts
121
  )
122
 
123
- # Execute the graph
124
  executor = Executor(
125
  tasks,
126
  math_device=device,
127
- storage_device="cpu" # Force storage to CPU for low-resource safety
128
  )
129
  executor.execute()
130
  print("✅ Raw PyTorch Merge Complete.")
@@ -138,9 +138,6 @@ def build_full_merge_config(
138
  method, models, base_model, weights, density,
139
  dtype, tokenizer_source, layer_ranges
140
  ):
141
- """
142
- Constructs the YAML dictionary for general merging (Linear, SLERP, TIES, etc.)
143
- """
144
  config = {
145
  "merge_method": method.lower(),
146
  "base_model": base_model if base_model else models[0],
@@ -153,22 +150,17 @@ def build_full_merge_config(
153
  if weights:
154
  try:
155
  w_list = [float(x.strip()) for x in weights.split(',')]
156
- except:
157
- pass
158
 
159
  for i, m in enumerate(models):
160
  entry = {"model": m, "parameters": {}}
161
-
162
- # Method Specific Param Injection
163
  if method.lower() in ["ties", "dare_ties", "dare_linear"]:
164
  entry["parameters"]["weight"] = w_list[i] if i < len(w_list) else 1.0
165
  entry["parameters"]["density"] = density
166
  elif method.lower() in ["slerp", "linear"]:
167
  entry["parameters"]["weight"] = w_list[i] if i < len(w_list) else 1.0
168
-
169
  config["models"].append(entry)
170
 
171
- # Inject Slices/Layer Ranges if provided
172
  if layer_ranges and layer_ranges.strip():
173
  try:
174
  extra_params = yaml.safe_load(layer_ranges)
@@ -181,11 +173,16 @@ def build_full_merge_config(
181
 
182
  def build_moe_config(
183
  base_model, experts, prompts, gate_mode, dtype,
184
- tokenizer_source
185
  ):
186
  """
187
  Constructs the YAML dictionary for MoE.
188
- Maps prompts to experts if provided.
 
 
 
 
 
189
  """
190
  config = {
191
  "base_model": base_model,
@@ -194,25 +191,34 @@ def build_moe_config(
194
  "tokenizer_source": tokenizer_source,
195
  "experts": []
196
  }
197
-
 
 
 
 
198
  for i, exp in enumerate(experts):
199
  expert_entry = {"source_model": exp}
200
 
201
- # Map prompt if available
202
- # "positive_prompts" is required for "hidden" gate mode
203
- if i < len(prompts) and prompts[i].strip():
 
204
  expert_entry["positive_prompts"] = [prompts[i].strip()]
205
- # If hidden mode is forced but no prompt, we might fail validation
206
- # But we leave it to the validator to complain if strictly required
207
 
208
  config["experts"].append(expert_entry)
209
-
 
 
 
 
 
 
 
 
 
210
  return config
211
 
212
  def build_raw_config(method, models, base_model, dtype, weights):
213
- """
214
- Constructs the YAML for Raw PyTorch merging.
215
- """
216
  config = {
217
  "merge_method": method.lower(),
218
  "dtype": dtype,
@@ -230,7 +236,6 @@ def build_raw_config(method, models, base_model, dtype, weights):
230
 
231
  for i, m in enumerate(models):
232
  entry = {"model": m, "parameters": {}}
233
- # Most raw methods just use weight
234
  entry["parameters"]["weight"] = w_list[i] if i < len(w_list) else 1.0
235
  config["models"].append(entry)
236
 
 
4
  import torch
5
  import shutil
6
  import sys
7
+ import warnings
8
  from pathlib import Path
9
 
10
+ # --- SILENCE PYDANTIC WARNINGS ---
11
+ warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
12
+
13
  # --- CRITICAL PATCH: MUST RUN BEFORE MERGEKIT IMPORTS ---
14
  import pydantic
15
  from pydantic import ConfigDict, BaseModel
 
16
  BaseModel.model_config = ConfigDict(arbitrary_types_allowed=True)
17
 
18
  try:
 
33
 
34
  def execute_mergekit_config(config_dict, out_path, shard_gb, device="cpu"):
35
  """
36
+ Executes a MergeKit run.
37
  """
38
+ # Force garbage collection
39
  gc.collect()
40
+ if torch.cuda.is_available():
41
+ torch.cuda.empty_cache()
42
 
43
  # Shared Options
44
  merge_opts = MergeOptions(
 
47
  lazy_unpickle=True,
48
  low_cpu_memory=True,
49
  max_shard_size=int(shard_gb * 1024**3),
50
+ allow_crimes=True
51
  )
52
 
53
  # --- BRANCH 1: MIXTURE OF EXPERTS (MoE) ---
 
72
  except Exception as e:
73
  raise RuntimeError(f"MoE Build Failed: {e}")
74
 
75
+ # --- BRANCH 2: STANDARD MERGE ---
76
  else:
77
  print("⚡ Detected Standard Merge Configuration.")
78
  try:
79
  # Validate using the Standard Schema
80
  conf = MergeConfiguration.model_validate(config_dict)
81
 
 
82
  run_merge(
83
  conf,
84
  out_path=out_path,
 
103
  """
104
  print("🧠 Executing Raw PyTorch Merge...")
105
  try:
 
106
  conf = RawPyTorchMergeConfig.model_validate(config_dict)
107
 
108
  merge_opts = MergeOptions(
 
113
  safe_serialization=True
114
  )
115
 
 
116
  tasks = plan_flat_merge(
117
  conf,
118
  out_path,
 
121
  options=merge_opts
122
  )
123
 
 
124
  executor = Executor(
125
  tasks,
126
  math_device=device,
127
+ storage_device="cpu"
128
  )
129
  executor.execute()
130
  print("✅ Raw PyTorch Merge Complete.")
 
138
  method, models, base_model, weights, density,
139
  dtype, tokenizer_source, layer_ranges
140
  ):
 
 
 
141
  config = {
142
  "merge_method": method.lower(),
143
  "base_model": base_model if base_model else models[0],
 
150
  if weights:
151
  try:
152
  w_list = [float(x.strip()) for x in weights.split(',')]
153
+ except: pass
 
154
 
155
  for i, m in enumerate(models):
156
  entry = {"model": m, "parameters": {}}
 
 
157
  if method.lower() in ["ties", "dare_ties", "dare_linear"]:
158
  entry["parameters"]["weight"] = w_list[i] if i < len(w_list) else 1.0
159
  entry["parameters"]["density"] = density
160
  elif method.lower() in ["slerp", "linear"]:
161
  entry["parameters"]["weight"] = w_list[i] if i < len(w_list) else 1.0
 
162
  config["models"].append(entry)
163
 
 
164
  if layer_ranges and layer_ranges.strip():
165
  try:
166
  extra_params = yaml.safe_load(layer_ranges)
 
173
 
174
  def build_moe_config(
175
  base_model, experts, prompts, gate_mode, dtype,
176
+ tokenizer_source, shared_experts=None
177
  ):
178
  """
179
  Constructs the YAML dictionary for MoE.
180
+
181
+ Key Logic based on MergeKit source:
182
+ - 'random'/'uniform_random' modes do NOT require prompts.
183
+ - 'hidden'/'cheap_embed' modes REQUIRE prompts.
184
+ - Qwen2 MoE requires exactly one shared expert.
185
+ - Mixtral requires ZERO shared experts.
186
  """
187
  config = {
188
  "base_model": base_model,
 
191
  "tokenizer_source": tokenizer_source,
192
  "experts": []
193
  }
194
+
195
+ # Handle Experts
196
+ if len(prompts) < len(experts):
197
+ prompts += [""] * (len(experts) - len(prompts))
198
+
199
  for i, exp in enumerate(experts):
200
  expert_entry = {"source_model": exp}
201
 
202
+ # Only attach prompts if they exist.
203
+ # mergekit.moe.config.is_bad_config will fail if prompts are missing
204
+ # BUT ONLY IF gate_mode != "random".
205
+ if prompts[i].strip():
206
  expert_entry["positive_prompts"] = [prompts[i].strip()]
 
 
207
 
208
  config["experts"].append(expert_entry)
209
+
210
+ # Handle Shared Experts (Required for Qwen2, Optional for DeepSeek)
211
+ if shared_experts:
212
+ config["shared_experts"] = []
213
+ for sh_exp in shared_experts:
214
+ # Shared experts usually don't use gating prompts in MergeKit implementations
215
+ # (DeepSeek forbids them, Qwen2 requires them if not random)
216
+ # We add a basic entry here; users might need advanced YAML editing for complex shared gating.
217
+ config["shared_experts"].append({"source_model": sh_exp})
218
+
219
  return config
220
 
221
  def build_raw_config(method, models, base_model, dtype, weights):
 
 
 
222
  config = {
223
  "merge_method": method.lower(),
224
  "dtype": dtype,
 
236
 
237
  for i, m in enumerate(models):
238
  entry = {"model": m, "parameters": {}}
 
239
  entry["parameters"]["weight"] = w_list[i] if i < len(w_list) else 1.0
240
  config["models"].append(entry)
241