AlekseyCalvin commited on
Commit
832111c
·
verified ·
1 Parent(s): b5e9acb

Create dare_utils.py

Browse files
Files changed (1) hide show
  1. dare_utils.py +94 -0
dare_utils.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import gc
4
+ from safetensors.torch import load_file, save_file
5
+ from huggingface_hub import hf_hub_download, HfApi
6
+ from tqdm import tqdm
7
+ import numpy as np
8
+
9
+ def download_file(repo_id, filename, token, local_dir):
10
+ return hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=local_dir)
11
+
12
+ def task_dare_custom(hf_token, base_repo, fine_tuned_repo, ratio, mask_rate, out_repo, private):
13
+ """
14
+ Custom DARE implementation:
15
+ 1. Load Base & Fine-Tuned
16
+ 2. Delta = FT - Base
17
+ 3. Mask = Bernoulli(1 - mask_rate)
18
+ 4. Rescale = 1 / (1 - mask_rate)
19
+ 5. New = Base + (Delta * Mask * Rescale * Ratio)
20
+ """
21
+ api = HfApi(token=hf_token)
22
+ temp_dir = "./temp_dare"
23
+ os.makedirs(temp_dir, exist_ok=True)
24
+
25
+ try:
26
+ # 1. Identify Model Files (Naive: get first .safetensors)
27
+ print("Locating files...")
28
+ base_files = api.list_repo_files(base_repo)
29
+ base_sf = next((f for f in base_files if f.endswith(".safetensors") and "model" in f), None)
30
+
31
+ ft_files = api.list_repo_files(fine_tuned_repo)
32
+ ft_sf = next((f for f in ft_files if f.endswith(".safetensors") and "model" in f), None)
33
+
34
+ if not base_sf or not ft_sf:
35
+ return "Error: Could not locate .safetensors in one of the repos."
36
+
37
+ print(f"Downloading {base_sf}...")
38
+ base_path = download_file(base_repo, base_sf, hf_token, temp_dir)
39
+ print(f"Downloading {ft_sf}...")
40
+ ft_path = download_file(fine_tuned_repo, ft_sf, hf_token, temp_dir)
41
+
42
+ # 2. Process
43
+ print("Loading tensors...")
44
+ base_sd = load_file(base_path, device="cpu")
45
+ ft_sd = load_file(ft_path, device="cpu")
46
+
47
+ merged_sd = {}
48
+ keys = set(base_sd.keys()).intersection(ft_sd.keys())
49
+
50
+ scale_factor = 1.0 / (1.0 - mask_rate)
51
+
52
+ print("Applying DARE...")
53
+ for k in tqdm(keys):
54
+ b_tensor = base_sd[k]
55
+ f_tensor = ft_sd[k]
56
+
57
+ if b_tensor.shape != f_tensor.shape:
58
+ merged_sd[k] = f_tensor # Fallback
59
+ continue
60
+
61
+ # Skip 1D tensors (LayerNorms usually) or non-float
62
+ if len(b_tensor.shape) < 2 or not b_tensor.is_floating_point():
63
+ merged_sd[k] = f_tensor # Keep FT version
64
+ continue
65
+
66
+ # Calculate Delta
67
+ delta = f_tensor - b_tensor
68
+
69
+ # Create Mask (Bernoulli)
70
+ mask = torch.bernoulli(torch.full_like(delta, 1.0 - mask_rate))
71
+
72
+ # Apply DARE formula
73
+ # New = Base + (Delta * Mask * Scale * Ratio)
74
+ final = b_tensor + (delta * mask * scale_factor * ratio)
75
+
76
+ merged_sd[k] = final.to(torch.bfloat16) # Enforce BF16 for save
77
+
78
+ # 3. Save
79
+ out_path = os.path.join(temp_dir, "dare_merged.safetensors")
80
+ save_file(merged_sd, out_path)
81
+
82
+ # 4. Upload
83
+ print("Uploading...")
84
+ api.create_repo(repo_id=out_repo, private=private, exist_ok=True)
85
+ api.upload_file(path_or_fileobj=out_path, path_in_repo="model.safetensors", repo_id=out_repo)
86
+
87
+ return f"Done! Uploaded to {out_repo}"
88
+
89
+ except Exception as e:
90
+ return f"DARE Error: {e}"
91
+ finally:
92
+ if os.path.exists(temp_dir):
93
+ shutil.rmtree(temp_dir)
94
+ gc.collect()