File size: 1,385 Bytes
73ca183
3c33946
 
 
 
 
 
 
 
 
 
 
 
 
 
73ca183
49dc183
 
 
 
 
3c33946
 
 
 
 
 
 
49dc183
 
 
 
 
 
 
 
 
 
 
 
 
92792f7
c1b23ca
49dc183
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import pydantic
from pydantic import ConfigDict, BaseModel

# --- GLOBAL FIX FOR PYDANTIC V2 / MERGEKIT COMPATIBILITY ---
# This MUST happen before importing any mergekit modules
def patch_pydantic():
    """Enable ``arbitrary_types_allowed`` on pydantic's ``BaseModel``.

    Overwrites ``BaseModel.model_config`` with
    ``ConfigDict(arbitrary_types_allowed=True)``. This is a process-wide
    side effect on the pydantic base class, intended to become the default
    configuration for models declared afterwards, so it must be called
    before the modules that declare those models are imported.

    Returns:
        None.
    """
    # An ``__init__`` wrapper used to be defined here but was never installed
    # on BaseModel, so it had no effect; only the config assignment matters.
    BaseModel.model_config = ConfigDict(arbitrary_types_allowed=True)

patch_pydantic()
# ---------------------------------------------------------

import yaml
import os
import torch
from pathlib import Path

# Now it is safe to import mergekit
try:
    from mergekit.config import MergeConfiguration
    from mergekit.merge import run_merge
except ImportError as e:
    # Best-effort import: the failure is only printed, not re-raised, so this
    # module still loads. In that case MergeConfiguration and run_merge are
    # left undefined, and any later use of them raises NameError.
    print(f"Error importing mergekit: {e}")

def execute_mergekit(config_dict, out_path, hf_token, shard_gb=5.0):
    """Run a MergeKit merge on the CPU from a dictionary configuration.

    Args:
        config_dict: Merge configuration as a plain dictionary. It is
            validated into a ``MergeConfiguration`` before merging.
        out_path: Output location handed to ``run_merge``.
        hf_token: Accepted for interface compatibility; this function does
            not currently use it.
        shard_gb: Maximum output shard size in GiB. It is multiplied by
            1024**3 and truncated to a whole number of bytes. Must be
            positive.

    Returns:
        ``True`` once ``run_merge`` has returned without raising.

    Raises:
        ValueError: If ``shard_gb`` is not positive.
    """
    # ``shard_gb`` used to be read from an undefined name, which raised
    # NameError on every call; it is now an explicit, validated parameter.
    if shard_gb <= 0:
        raise ValueError(f"shard_gb must be positive, got {shard_gb!r}")

    # Round-trip through YAML text so the configuration reaches the
    # validator as plain data loaded by ``yaml.safe_load``.
    config_yaml = yaml.dump(config_dict)
    conf = MergeConfiguration.model_validate(yaml.safe_load(config_yaml))

    # NOTE(review): confirm these keyword arguments match the installed
    # mergekit version; upstream ``run_merge`` may expect a ``MergeOptions``
    # object instead of individual keywords.
    run_merge(
        conf,
        out_path,
        device="cpu",
        low_cpu_mem=True,
        copy_tokenizer=True,
        lazy_unpickle=True,
        max_shard_size=int(shard_gb * 1024 ** 3),
    )
    return True