File size: 3,408 Bytes
832111c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import gc
import os
import shutil
import tempfile

import numpy as np
import torch
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import load_file, save_file
from tqdm import tqdm

def download_file(repo_id, filename, token, local_dir):
    """Fetch a single file from a Hugging Face Hub repo into ``local_dir``.

    Thin pass-through to ``hf_hub_download``; whatever it returns (per the
    huggingface_hub docs, the local path of the downloaded file) is returned.
    """
    request = dict(repo_id=repo_id, filename=filename, token=token, local_dir=local_dir)
    return hf_hub_download(**request)

def _select_safetensors(repo_files):
    """Return the sorted weight files to merge from a repo file listing.

    Every root-level ``*.safetensors`` file whose name contains "model" is
    selected, so sharded checkpoints (``model-00001-of-0000N.safetensors``)
    are merged in full. If none sit at the repo root, fall back to the first
    matching file anywhere (the previous single-file behaviour). May be empty.
    """
    candidates = [f for f in repo_files if f.endswith(".safetensors") and "model" in f]
    root_level = sorted(f for f in candidates if "/" not in f)
    return root_level if root_level else candidates[:1]


def _load_repo_weights(repo_id, filenames, token, local_dir):
    """Download ``filenames`` from ``repo_id`` into ``local_dir`` and merge them into one CPU state dict."""
    os.makedirs(local_dir, exist_ok=True)
    state_dict = {}
    for name in filenames:
        print(f"Downloading {name}...")
        path = download_file(repo_id, name, token, local_dir)
        state_dict.update(load_file(path, device="cpu"))
    return state_dict


def task_dare_custom(hf_token, base_repo, fine_tuned_repo, ratio, mask_rate, out_repo, private):
    """
    Custom DARE (drop-and-rescale) merge, uploaded as one safetensors file.

    For every floating-point tensor of rank >= 2 that exists in both
    checkpoints with the same shape:
        delta = FT - Base
        mask  = Bernoulli(1 - mask_rate)          (elementwise)
        new   = Base + delta * mask * (1 / (1 - mask_rate)) * ratio
    The arithmetic is done in float32 and the result is stored as bfloat16.
    Every other fine-tuned tensor (rank < 2, non-float, shape mismatch, or
    absent from the base) is copied unchanged from the fine-tuned checkpoint.

    Args:
        hf_token: Hugging Face token used to list, download and upload.
        base_repo: Repo id of the base model.
        fine_tuned_repo: Repo id of the fine-tuned model.
        ratio: Multiplier applied to the masked, rescaled delta.
        mask_rate: Probability of dropping each delta element; must be in [0, 1).
        out_repo: Destination repo id (created if it does not exist).
        private: Passed to ``create_repo`` for the destination repo.

    Returns:
        A status string: "Done! Uploaded to <out_repo>" on success, otherwise
        one starting with "Error:" or "DARE Error:". Errors raised after the
        scratch directory is created are reported here rather than raised.
    """
    api = HfApi(token=hf_token)
    # Unique per-call scratch dir (in the CWD, like the old "./temp_dare") so
    # concurrent runs cannot overwrite or delete each other's files.
    temp_dir = tempfile.mkdtemp(prefix="temp_dare_", dir=".")
    base_sd = ft_sd = merged_sd = None

    try:
        mask_rate = float(mask_rate)
        ratio = float(ratio)
        # mask_rate == 1 would divide by zero; outside [0, 1] bernoulli fails.
        if not 0.0 <= mask_rate < 1.0:
            raise ValueError(f"mask_rate must be in [0, 1), got {mask_rate}")

        # 1. Identify model files (all root-level shards, see _select_safetensors)
        print("Locating files...")
        base_files = _select_safetensors(api.list_repo_files(base_repo))
        ft_files = _select_safetensors(api.list_repo_files(fine_tuned_repo))
        if not base_files or not ft_files:
            return "Error: Could not locate .safetensors in one of the repos."

        # Separate sub-directories: both repos typically call their weights
        # "model.safetensors", and a shared local_dir lets the fine-tuned
        # download overwrite the base file (making every delta zero).
        base_sd = _load_repo_weights(base_repo, base_files, hf_token, os.path.join(temp_dir, "base"))
        ft_sd = _load_repo_weights(fine_tuned_repo, ft_files, hf_token, os.path.join(temp_dir, "fine_tuned"))

        # 2. Process
        scale_factor = 1.0 / (1.0 - mask_rate)
        merged_sd = {}

        print("Applying DARE...")
        for k in tqdm(list(ft_sd.keys())):
            # pop() releases each input tensor once handled, lowering peak RAM.
            f_tensor = ft_sd.pop(k)
            b_tensor = base_sd.pop(k, None)

            # Keep the fine-tuned tensor when there is no comparable base
            # tensor, or for 1D (usually norm/bias) and non-float tensors.
            if (
                b_tensor is None
                or b_tensor.shape != f_tensor.shape
                or b_tensor.dim() < 2
                or not b_tensor.is_floating_point()
                or not f_tensor.is_floating_point()
            ):
                merged_sd[k] = f_tensor
                continue

            # float32 math avoids bf16/fp16 cancellation error in the delta.
            b32 = b_tensor.float()
            delta = f_tensor.float() - b32
            mask = torch.bernoulli(torch.full_like(delta, 1.0 - mask_rate))
            final = b32 + delta * mask * scale_factor * ratio
            merged_sd[k] = final.to(torch.bfloat16)  # Enforce BF16 for save

        # 3. Save
        out_path = os.path.join(temp_dir, "dare_merged.safetensors")
        save_file(merged_sd, out_path)

        # 4. Upload
        print("Uploading...")
        api.create_repo(repo_id=out_repo, private=private, exist_ok=True)
        api.upload_file(path_or_fileobj=out_path, path_in_repo="model.safetensors", repo_id=out_repo)

        return f"Done! Uploaded to {out_repo}"

    except Exception as e:
        return f"DARE Error: {e}"
    finally:
        # Best-effort cleanup that can never replace the status string.
        shutil.rmtree(temp_dir, ignore_errors=True)
        # Drop the large dicts before collecting so their memory is freed now.
        base_sd = ft_sd = merged_sd = None
        gc.collect()