AlekseyCalvin commited on
Commit
b5e9acb
·
verified ·
1 Parent(s): 12c0073

Create merge_utils.py

Browse files
Files changed (1) hide show
  1. merge_utils.py +128 -0
merge_utils.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import yaml
import gc
import torch
import shutil
from pathlib import Path

# --- CRITICAL PATCH: MUST RUN BEFORE MERGEKIT IMPORTS ---
# Globally allow arbitrary (non-pydantic) types on every pydantic model.
# NOTE(review): this monkey-patches BaseModel.model_config for the WHOLE
# process, not just mergekit's models — confirm no other pydantic users in
# this app rely on the stricter default.
import pydantic
from pydantic import ConfigDict, BaseModel
BaseModel.model_config = ConfigDict(arbitrary_types_allowed=True)

try:
    from mergekit.config import MergeConfiguration
    from mergekit.merge import run_merge
    from mergekit.architecture import get_architecture_info
except ImportError:
    # Best-effort import: the module still loads without mergekit, but any
    # later call that references MergeConfiguration / run_merge will raise
    # NameError at call time rather than import time.
    print("MergeKit not found. Please install 'mergekit' in requirements.txt")
20
def execute_mergekit_config(config_dict, out_path, shard_gb, device="cpu"):
    """
    Run a MergeKit merge described by ``config_dict``.

    Tuned for CPU execution: lazy unpickling, low-memory loading, and
    aggressive output sharding.

    Parameters
    ----------
    config_dict : dict
        MergeKit configuration as a plain dictionary.
    out_path : str
        Directory that receives the merged model.
    shard_gb : float
        Maximum shard size in gigabytes.
    device : str
        Device passed through to MergeKit (default ``"cpu"``).
    """
    # Round-trip the dict through YAML so validation follows the same
    # code path as a hand-written config file.
    rendered_yaml = yaml.dump(config_dict)

    print("--- Generated MergeKit Config ---")
    print(rendered_yaml)
    print("---------------------------------")

    merge_conf = MergeConfiguration.model_validate(yaml.safe_load(rendered_yaml))

    shard_bytes = int(shard_gb * 1024 ** 3)
    run_merge(
        merge_conf,
        out_path=out_path,
        device=device,
        low_cpu_mem=True,
        copy_tokenizer=True,
        lazy_unpickle=True,
        max_shard_size=shard_bytes,
    )

    # Encourage prompt release of any lingering tensor references.
    gc.collect()
46
+
47
def build_full_merge_config(
    method, models, base_model, weights, density,
    dtype, tokenizer_source, layer_ranges
):
    """
    Construct the configuration dictionary for a general MergeKit merge
    (linear, slerp, ties, dare_ties, dare_linear, ...).

    Parameters
    ----------
    method : str
        Merge method name; lower-cased before use.
    models : list[str]
        Model identifiers to merge. ``models[0]`` doubles as the base
        model when ``base_model`` is falsy.
    base_model : str | None
        Explicit base model, or falsy to default to ``models[0]``.
    weights : str | None
        Comma-separated per-model weights (e.g. ``"0.3,0.7"``). Missing
        or unparsable entries fall back to 1.0.
    density : float
        Density injected for TIES/DARE-style methods.
    dtype : str
        Torch dtype string forwarded verbatim (e.g. ``"bfloat16"``).
    tokenizer_source : str
        Tokenizer-source setting forwarded verbatim.
    layer_ranges : str | None
        Optional raw YAML/JSON string merged into the config root
        (e.g. a ``slices`` section). Ignored when None or blank.

    Returns
    -------
    dict
        A MergeKit-compatible configuration dictionary.
    """
    method_key = method.lower()

    config = {
        "merge_method": method_key,
        "base_model": base_model if base_model else models[0],
        "dtype": dtype,
        "tokenizer_source": tokenizer_source,
        "models": [],
    }

    # Parse the comma-separated weight string defensively.
    w_list = []
    if weights:
        try:
            w_list = [float(x.strip()) for x in weights.split(',')]
        # BUG FIX: was a bare `except:` — only parse failures are expected
        # here, so catch ValueError instead of swallowing everything.
        except ValueError:
            print("Warning: Could not parse weights, defaulting to 1.0")

    def _weight(i):
        # Per-model weight, defaulting to 1.0 when not supplied.
        return w_list[i] if i < len(w_list) else 1.0

    for i, m in enumerate(models):
        entry = {"model": m, "parameters": {}}

        # Method-specific parameter injection.
        if method_key in ("ties", "dare_ties", "dare_linear"):
            entry["parameters"]["weight"] = _weight(i)
            entry["parameters"]["density"] = density
        elif method_key == "slerp":
            # SLERP usually takes 't' via weight; an explicit slice config
            # in layer_ranges takes precedence and is merged at root level.
            if layer_ranges and "slices" in layer_ranges:
                pass  # slices are injected into the config root below
            else:
                entry["parameters"]["weight"] = _weight(i)
        elif method_key == "linear":
            entry["parameters"]["weight"] = _weight(i)

        config["models"].append(entry)

    # Merge an optional raw slice/param override into the config root.
    # BUG FIX: the original called layer_ranges.strip() unguarded, which
    # raises AttributeError when layer_ranges is None — the SLERP branch
    # above already treats None as valid, so guard consistently here.
    if layer_ranges and layer_ranges.strip():
        try:
            extra_params = yaml.safe_load(layer_ranges)
            if isinstance(extra_params, dict):
                config.update(extra_params)
        except Exception as e:
            print(f"Error parsing layer ranges JSON: {e}")

    return config
104
+
105
def build_moe_config(
    base_model, experts, gate_mode, dtype,
    tokenizer_source, positive_prompts=None
):
    """
    Construct the configuration dictionary for a MergeKit
    Mixture-of-Experts (MoE) merge.

    Parameters
    ----------
    base_model : str
        Model providing the shared (non-expert) weights.
    experts : list[str]
        Expert model identifiers.
    gate_mode : str
        MergeKit gate mode, forwarded verbatim.
    dtype : str
        Torch dtype string, forwarded verbatim.
    tokenizer_source : str
        Tokenizer-source setting, forwarded verbatim.
    positive_prompts : list | None
        Optional per-expert routing prompts, aligned with ``experts``.
        Each entry may be a single string or a list of strings. Experts
        without a supplied prompt keep the ``"expert_<i>"`` placeholder.

    Returns
    -------
    dict
        A MergeKit MoE-compatible configuration dictionary.
    """
    config = {
        "base_model": base_model,
        "gate_mode": gate_mode,
        "dtype": dtype,
        "tokenizer_source": tokenizer_source,
        "experts": [],
    }

    for i, exp in enumerate(experts):
        # BUG FIX: positive_prompts was accepted but never used — every
        # expert always got the placeholder. Honour it when provided,
        # keeping the placeholder as a backward-compatible fallback.
        if positive_prompts is not None and i < len(positive_prompts) and positive_prompts[i]:
            supplied = positive_prompts[i]
            prompts = [supplied] if isinstance(supplied, str) else list(supplied)
        else:
            prompts = [f"expert_{i}"]  # placeholder when no prompt given

        config["experts"].append({
            "source_model": exp,
            "positive_prompts": prompts,
        })

    return config