# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: BUSL-1.1
import logging
import os
import shutil
import time
from typing import List, Optional

import click
import cma
import numpy as np
import pandas
import ray
import torch
import tqdm
import transformers
import yaml

try:
    import wandb
except ImportError:
    wandb = None

from mergekit.common import ModelReference
from mergekit.evo.config import (
    EvolMergeConfiguration,
    ModelGenomeDefinition,
    check_for_naughty_config,
)
from mergekit.evo.genome import ModelGenome
from mergekit.evo.strategy import (
    ActorPoolEvaluationStrategy,
    BufferedRayEvaluationStrategy,
    SerialEvaluationStrategy,
)
from mergekit.merge import run_merge
from mergekit.options import MergeOptions
# Evaluation scheduling strategies selectable via --strategy.
_STRATEGY_CLASSES = {
    "pool": ActorPoolEvaluationStrategy,
    "buffered": BufferedRayEvaluationStrategy,
    "serial": SerialEvaluationStrategy,
}


def _build_quantization_config(
    load_in_8bit: bool, load_in_4bit: bool, vllm: bool, in_memory: bool
) -> Optional[transformers.BitsAndBytesConfig]:
    """Validate the quantization flags and build a bitsandbytes config.

    Returns ``None`` when neither 4-bit nor 8-bit evaluation is requested.

    Raises:
        ValueError: if both bit widths are requested, or quantization is
            combined with vLLM or in-memory evaluation.
        RuntimeError: if ``bitsandbytes`` is not installed.
    """
    if load_in_4bit and load_in_8bit:
        raise ValueError("Cannot load models in both 4-bit and 8-bit")
    if not (load_in_4bit or load_in_8bit):
        return None
    if vllm:
        raise ValueError("Cannot use vLLM with 4-bit or 8-bit models")
    if in_memory:
        raise ValueError("Cannot use in-memory mode with 4-bit or 8-bit models")
    try:
        import bitsandbytes  # noqa: F401  (imported only to check availability)
    except ImportError as err:
        raise RuntimeError("bitsandbytes is not installed") from err
    return transformers.BitsAndBytesConfig(
        load_in_8bit=load_in_8bit,
        load_in_4bit=load_in_4bit,
        bnb_4bit_compute_dtype="bfloat16",
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )


def _log_population_metrics(run, results: List[dict]) -> None:
    """Log population-level score and per-metric statistics to a wandb run.

    Each entry of *results* is read as a mapping with a ``"score"`` and a
    nested ``"results"`` mapping of task -> metric -> value. The task/metric
    keys of the first entry decide which metrics are summarized. All values
    are logged with ``commit=False`` so they land on the next committed step.
    """
    if not results:
        return

    scores = [r["score"] for r in results]
    run.log(
        {
            "population/score_mean": np.mean(scores),
            "population/score_std": np.std(scores),
        },
        commit=False,
    )

    for task, metrics in (results[0].get("results") or {}).items():
        for metric in metrics:
            metric_pretty = metric.replace(",none", "")
            if metric_pretty.endswith("_stderr"):
                # don't log stats for stderr that's just silly
                continue

            values = []
            for r in results:
                value = (r.get("results") or {}).get(task, {}).get(metric)
                # Drop missing values and string entries so the numeric
                # reductions below never see mixed types.
                if value is not None and not isinstance(value, str):
                    values.append(value)
            if not values:
                continue

            run.log(
                {
                    f"population/{task}_{metric_pretty}_mean": np.mean(values),
                    f"population/{task}_{metric_pretty}_max": max(values),
                    f"population/{task}_{metric_pretty}_min": min(values),
                },
                commit=False,
            )


@click.command("mergekit-evolve")
@click.argument("genome-config-path", type=str)
@click.option("--max-fevals", type=int, default=100)
@click.option("--vllm/--no-vllm", is_flag=True, default=False, help="Use vLLM")
@click.option(
    "--strategy",
    "-s",
    type=click.Choice(["pool", "buffered", "serial"]),
    default="pool",
    help="Evaluation scheduling strategy",
)
@click.option(
    "--in-memory/--no-in-memory",
    is_flag=True,
    default=False,
    help="Use in-memory merge & evaluation",
)
@click.option(
    "--storage-path",
    type=str,
    help="Path to storage accessible to all nodes for model storage",
    required=True,
)
@click.option("--num-gpus", type=int, help="Number of GPUs to use across all nodes")
@click.option("--merge-cuda/--no-merge-cuda", is_flag=True, default=True)
@click.option("--trust-remote-code/--no-trust-remote-code", is_flag=True, default=False)
@click.option("--allow-crimes/--no-allow-crimes", is_flag=True, default=False)
@click.option("--random-seed", type=int, default=0)
@click.option("--batch-size", type=int, default=None, help="Batch size for evaluation")
@click.option("--sigma0", type=float, default=1 / 6, help="Initial sigma for CMA-ES")
@click.option("use_wandb", "--wandb/--no-wandb", is_flag=True, default=False)
@click.option("--wandb-project", type=str, help="Wandb project name")
@click.option("--wandb-entity", type=str, help="Wandb entity name")
@click.option(
    "--task-search-path",
    type=str,
    multiple=True,
    help="Path to search for lmeval tasks",
)
@click.option(
    "--i-understand-the-depths-of-the-evils-i-am-unleashing",
    "allow_benchmark_tasks",
    is_flag=True,
    default=False,
    help="Allow benchmark tasks as objectives",
)
@click.option(
    "--save-final-model/--no-save-final-model",
    is_flag=True,
    default=True,
    help="Save the final merged model",
)
@click.option(
    "--reshard/--no-reshard",
    is_flag=True,
    default=True,
    help="Convert models to single-shard safetensors for faster merge",
)
@click.option(
    "--timeout",
    type=float,
    default=None,
    help="Maximum time to run the optimization in seconds",
)
@click.option(
    "--load-in-8bit",
    is_flag=True,
    default=False,
    help="Evaluate models at 8-bit precision",
)
@click.option(
    "--load-in-4bit",
    is_flag=True,
    default=False,
    help="Evaluate models at 4-bit precision",
)
@click.option(
    "--force-population-size",
    type=int,
    default=None,
    help="Force a specific initial population size for CMA-ES",
)
def main(
    genome_config_path: str,
    max_fevals: int,
    vllm: bool,
    strategy: str,
    in_memory: bool,
    storage_path: str,
    num_gpus: Optional[int],
    merge_cuda: bool,
    trust_remote_code: bool,
    allow_crimes: bool,
    random_seed: int,
    batch_size: Optional[int],
    sigma0: float,
    use_wandb: bool,
    wandb_project: Optional[str],
    wandb_entity: Optional[str],
    task_search_path: List[str],
    allow_benchmark_tasks: bool,
    save_final_model: bool,
    reshard: bool,
    timeout: Optional[float],
    load_in_8bit: bool,
    load_in_4bit: bool,
    force_population_size: Optional[int],
):
    """Evolve a merge configuration for GENOME_CONFIG_PATH using CMA-ES.

    Candidate merges are generated from the genome definition and scored with
    the selected evaluation strategy. The best configuration found so far is
    written to best_config.yaml under --storage-path as the search progresses.
    \f
    Raises:
        ValueError: for incompatible quantization flags or an unknown strategy.
        RuntimeError: if a requested optional dependency (bitsandbytes, wandb)
            is not installed.
    """
    with open(genome_config_path, "r", encoding="utf-8") as config_file:
        config = EvolMergeConfiguration.model_validate(yaml.safe_load(config_file))
    check_for_naughty_config(config, allow=allow_benchmark_tasks)

    bnb_config = _build_quantization_config(
        load_in_8bit=load_in_8bit,
        load_in_4bit=load_in_4bit,
        vllm=vllm,
        in_memory=in_memory,
    )

    run = None
    if use_wandb:
        if not wandb:
            raise RuntimeError("wandb is not installed")
        run = wandb.init(
            project=wandb_project or "mergekit-evolve",
            entity=wandb_entity,
            config=config.model_dump(mode="json"),
        )

    merge_options = MergeOptions(
        transformers_cache=os.path.join(storage_path, "transformers_cache"),
        lora_merge_cache=os.path.join(storage_path, "lora_merge_cache"),
        cuda=merge_cuda,
        low_cpu_memory=merge_cuda and not in_memory,
        out_shard_size=1_000_000_000_000,  # one trillion bytes!
        trust_remote_code=trust_remote_code,
        allow_crimes=allow_crimes,
        random_seed=random_seed,
        quiet=True,
        read_to_gpu=merge_cuda and not in_memory,
        copy_tokenizer=True,
        safe_serialization=True,
    )

    if reshard:
        # convert models to single-shard safetensors
        resharded_models = [
            _reshard_model(
                model,
                storage_path,
                merge_options.lora_merge_cache,
                trust_remote_code,
            )
            for model in tqdm.tqdm(config.genome.models, desc="Resharding models")
        ]
        resharded_base = None
        if config.genome.base_model is not None:
            resharded_base = _reshard_model(
                config.genome.base_model,
                storage_path,
                merge_options.lora_merge_cache,
                trust_remote_code,
            )
    else:
        resharded_models = config.genome.models
        resharded_base = config.genome.base_model

    genome = ModelGenome(
        ModelGenomeDefinition.model_validate(
            {
                **config.genome.model_dump(exclude={"models", "base_model"}),
                "models": resharded_models,
                "base_model": resharded_base,
            }
        ),
        trust_remote_code=trust_remote_code,
    )

    try:
        strat_cls = _STRATEGY_CLASSES[strategy]
    except KeyError:
        raise ValueError(f"Unknown strategy {strategy}") from None

    strat = strat_cls(
        config,
        genome,
        merge_options,
        num_gpus=num_gpus,
        vllm=vllm,
        in_memory=in_memory,
        model_storage_path=os.path.join(storage_path, "merged"),
        batch_size=batch_size,
        task_search_path=task_search_path,
        quantization_config=bnb_config,
    )

    x0 = genome.initial_genotype(random=config.random_init).view(-1).numpy()
    xbest = x0
    xbest_cost = np.inf
    best_config_path = os.path.join(storage_path, "best_config.yaml")

    def progress_callback(es: cma.CMAEvolutionStrategy):
        """Per-iteration CMA-ES hook: log to wandb and persist any new best."""
        nonlocal xbest, xbest_cost

        res = es.result
        if use_wandb:
            best_params = genome.genotype_to_param_arrays(res.xbest)
            mean_params = genome.genotype_to_param_arrays(res.xfavorite)
            run.log(
                {
                    "best_score": -res.fbest,
                    "best_genome": wandb.Table(data=pandas.DataFrame(best_params)),
                    "mean_genome": wandb.Table(data=pandas.DataFrame(mean_params)),
                    "mean_std": genome.genotype_to_param_arrays(res.stds),
                    "evaluations": res.evaluations,
                },
                commit=True,
                step=res.evaluations,
            )

        if res.fbest < xbest_cost:
            xbest = res.xbest
            xbest_cost = res.fbest
            print(f"New best score: {-xbest_cost:.4f}")
            best_yaml = genome.genotype_merge_config(xbest).to_yaml()
            with open(best_config_path, "w", encoding="utf-8") as f:
                f.write(best_yaml)
            print(f"Merge configuration:\n{best_yaml}")

            if use_wandb:
                art = wandb.Artifact("best_config", type="merge_config")
                art.add_file(best_config_path)
                run.log_artifact(art)

    def parallel_evaluate(x: List[np.ndarray]) -> List[float]:
        """Score a whole population; returns costs (negated scores) for CMA-ES."""
        print(f"Received {len(x)} genotypes")
        results = list(strat.evaluate_genotypes(x))
        if use_wandb:
            _log_population_metrics(run, results)
        # CMA-ES minimizes, so negate the scores in order to maximize them.
        return [-r["score"] for r in results]

    # Only pass options that were actually set, so pycma falls back to its own
    # defaults instead of receiving None.
    cma_opts = {"maxfevals": max_fevals}
    if timeout is not None:
        cma_opts["timeout"] = timeout
    if force_population_size is not None:
        cma_opts["popsize"] = force_population_size

    try:
        xbest, es = cma.fmin2(
            None,
            parallel_objective=parallel_evaluate,
            x0=x0,
            sigma0=sigma0,
            options=cma_opts,
            callback=progress_callback,
        )
        xbest_cost = es.result.fbest
    except KeyboardInterrupt:
        # Stop early, but still report (and optionally save) the best
        # genotype recorded by progress_callback so far.
        print("Optimization interrupted; using best genotype found so far")
        ray.shutdown()

    print("!!! OPTIMIZATION COMPLETE !!!")
    print(f"Best cost: {xbest_cost:.4f}")
    print()

    # pause for a bit to let any CUDA-using processes clean up
    time.sleep(1.0)

    # save the best merge configuration using original model references
    genome_pretty = ModelGenome(config.genome, trust_remote_code=trust_remote_code)
    best_config = genome_pretty.genotype_merge_config(xbest)
    print("Best merge configuration:")
    print(best_config.to_yaml())

    if save_final_model:
        print("Saving final model...")
        run_merge(best_config, os.path.join(storage_path, "final_model"), merge_options)

    if run is not None:
        run.finish()
def _reshard_model(
    model: ModelReference, storage_path: str, merge_cache: str, trust_remote_code: bool
) -> ModelReference:
    """Write *model* as a single-shard safetensors checkpoint on shared storage.

    The reference is first resolved with ``model.merged(...)`` (LoRA merge
    cache and bfloat16 merge dtype). The result is loaded in bfloat16 and
    saved under ``<storage_path>/input_models/<unique id>``. An existing
    output directory is reused as-is. The original model's tokenizer is
    saved alongside when it can be loaded; failure to do so is only logged.

    Output is first written to a ``.partial`` sibling directory and moved
    into place at the end. An interrupted run therefore never leaves a
    half-written directory that the existence check would later trust.

    Returns:
        A ``ModelReference`` pointing at the resharded directory.
    """
    merged = model.merged(
        cache_dir=merge_cache,
        trust_remote_code=trust_remote_code,
        lora_merge_dtype="bfloat16",
    )
    out_path = os.path.join(
        storage_path,
        "input_models",
        merged.model._unique_id(),
    )

    if os.path.exists(out_path):
        logging.info("Using existing resharded model at %s", out_path)
        return ModelReference(model=out_path)

    cache_dir = os.path.join(storage_path, "transformers_cache")
    tmp_path = f"{out_path}.partial"
    # Clear leftovers from a previously interrupted attempt.
    shutil.rmtree(tmp_path, ignore_errors=True)
    try:
        model_hf = transformers.AutoModelForCausalLM.from_pretrained(
            merged.model.path,
            revision=merged.model.revision,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch.bfloat16,
            cache_dir=cache_dir,
        )
        # transformers names this argument `max_shard_size`. An unknown
        # keyword such as `out_shard_size` is silently absorbed by **kwargs,
        # and the library's default shard limit would apply instead.
        # One trillion bytes keeps the checkpoint in a single shard.
        model_hf.save_pretrained(
            tmp_path, safe_serialization=True, max_shard_size=1_000_000_000_000
        )

        try:
            tokenizer = transformers.AutoTokenizer.from_pretrained(
                model.model.path,
                revision=model.model.revision,
                trust_remote_code=trust_remote_code,
                use_fast=True,
                cache_dir=cache_dir,
            )
            tokenizer.save_pretrained(tmp_path)
        except Exception as e:
            # Best effort: a missing tokenizer should not abort resharding.
            logging.warning(f"Could not save tokenizer for {model.model}", exc_info=e)

        os.replace(tmp_path, out_path)
    except BaseException:
        shutil.rmtree(tmp_path, ignore_errors=True)
        raise

    return ModelReference(model=out_path)
if __name__ == "__main__":
    # Run as a script rather than imported: hand control to the click CLI.
    main()