# Soon_Merger_Toolkit / dare_utils.py
# Repository page header (author avatar: AlekseyCalvin)
# Commit message: Create dare_utils.py
# Commit: 9c6d670 (verified)
import gc
import os
import shutil

import numpy as np
import torch
from huggingface_hub import hf_hub_download, HfApi
from safetensors.torch import load_file, save_file
from tqdm import tqdm
def download_file(repo_id, filename, token, local_dir):
    """Download a single file from a Hub repository into ``local_dir``.

    Thin wrapper: the four arguments are forwarded unchanged, as keyword
    arguments, to ``hf_hub_download``.

    Args:
        repo_id: Repository identifier passed through as ``repo_id``.
        filename: Name of the file inside the repository.
        token: Access token passed through as ``token``.
        local_dir: Target directory passed through as ``local_dir``.

    Returns:
        Whatever ``hf_hub_download`` returns (callers in this module treat it
        as the local path of the downloaded file).
    """
    request = {
        "repo_id": repo_id,
        "filename": filename,
        "token": token,
        "local_dir": local_dir,
    }
    return hf_hub_download(**request)
def _find_weights_file(api, repo_id):
    """Return the first repo file ending in ".safetensors" whose name contains "model", else None."""
    return next(
        (f for f in api.list_repo_files(repo_id) if f.endswith(".safetensors") and "model" in f),
        None,
    )


def _dare_merge_tensor(base, tuned, ratio, mask_rate):
    """Drop-and-rescale one tensor pair: base + delta * mask * scale * ratio, cast to BF16."""
    delta = tuned - base
    # Each delta element is kept with probability (1 - mask_rate) ...
    keep_mask = torch.bernoulli(torch.full_like(delta, 1.0 - mask_rate))
    # ... and the survivors are scaled up by 1 / (1 - mask_rate).
    scale_factor = 1.0 / (1.0 - mask_rate)
    merged = base + (delta * keep_mask * scale_factor * ratio)
    return merged.to(torch.bfloat16)  # Enforce BF16 for save


def task_dare_custom(hf_token, base_repo, fine_tuned_repo, ratio, mask_rate, out_repo, private):
    """
    Custom DARE implementation:
    1. Load Base & Fine-Tuned
    2. Delta = FT - Base
    3. Mask = Bernoulli(1 - mask_rate)
    4. Rescale = 1 / (1 - mask_rate)
    5. New = Base + (Delta * Mask * Rescale * Ratio)

    The merged weights are saved as a single safetensors file and uploaded to
    ``out_repo`` as ``model.safetensors``.

    Args:
        hf_token: Hub token used for listing, downloading and uploading.
        base_repo: Repo id of the base model.
        fine_tuned_repo: Repo id of the fine-tuned model.
        ratio: Multiplier applied to the masked, rescaled delta.
        mask_rate: Drop probability for each delta element; must satisfy
            ``0 <= mask_rate < 1`` (1.0 would make the rescale factor infinite).
        out_repo: Repo id that receives the merged file (created if missing).
        private: Passed to ``create_repo`` as the ``private`` flag.

    Returns:
        A human-readable status string. Failures are reported as a string
        starting with "Error:" or "DARE Error:" rather than raised.

    Notes:
        * Only the first matching ``.safetensors`` file of each repo is used
          (naive selection; additional shards are not merged).
        * Only tensor names present in BOTH checkpoints are written.
        * Tensors with mismatched shapes, fewer than 2 dimensions, or a
          non-floating dtype are copied from the fine-tuned checkpoint as-is.
    """
    # Validate up front so a bad mask_rate fails before any download happens
    # (mask_rate == 1 divides by zero; values outside [0, 1) are not probabilities).
    try:
        mask_rate = float(mask_rate)
    except (TypeError, ValueError):
        return f"DARE Error: mask_rate must be a number, got {mask_rate!r}"
    if not 0.0 <= mask_rate < 1.0:  # also rejects NaN
        return f"DARE Error: mask_rate must satisfy 0 <= mask_rate < 1, got {mask_rate}"

    api = HfApi(token=hf_token)
    temp_dir = "./temp_dare"
    os.makedirs(temp_dir, exist_ok=True)
    try:
        # 1. Identify Model Files (Naive: get first .safetensors)
        print("Locating files...")
        base_sf = _find_weights_file(api, base_repo)
        ft_sf = _find_weights_file(api, fine_tuned_repo)
        if not base_sf or not ft_sf:
            return "Error: Could not locate .safetensors in one of the repos."

        # Separate download directories: both repos commonly use the same
        # filename, and a shared directory would let the second download
        # overwrite the first, making base and fine-tuned weights identical.
        base_dir = os.path.join(temp_dir, "base")
        ft_dir = os.path.join(temp_dir, "fine_tuned")
        os.makedirs(base_dir, exist_ok=True)
        os.makedirs(ft_dir, exist_ok=True)

        print(f"Downloading {base_sf}...")
        base_path = download_file(base_repo, base_sf, hf_token, base_dir)
        print(f"Downloading {ft_sf}...")
        ft_path = download_file(fine_tuned_repo, ft_sf, hf_token, ft_dir)

        # 2. Process
        print("Loading tensors...")
        base_sd = load_file(base_path, device="cpu")
        ft_sd = load_file(ft_path, device="cpu")
        merged_sd = {}
        # Stable key order (instead of set iteration) so the random masks are
        # drawn in the same sequence on every run with the same torch seed.
        shared_keys = [k for k in base_sd if k in ft_sd]

        print("Applying DARE...")
        for k in tqdm(shared_keys):
            b_tensor = base_sd[k]
            f_tensor = ft_sd[k]
            if b_tensor.shape != f_tensor.shape:
                merged_sd[k] = f_tensor  # Fallback
                continue
            # Skip 1D tensors (LayerNorms usually) or non-float
            if b_tensor.dim() < 2 or not b_tensor.is_floating_point():
                merged_sd[k] = f_tensor  # Keep FT version
                continue
            merged_sd[k] = _dare_merge_tensor(b_tensor, f_tensor, ratio, mask_rate)

        # The source dicts are no longer needed; drop them before writing to
        # keep peak memory down (tensors reused in merged_sd stay alive).
        del base_sd, ft_sd

        # 3. Save
        out_path = os.path.join(temp_dir, "dare_merged.safetensors")
        save_file(merged_sd, out_path)
        del merged_sd

        # 4. Upload
        print("Uploading...")
        api.create_repo(repo_id=out_repo, private=private, exist_ok=True)
        api.upload_file(path_or_fileobj=out_path, path_in_repo="model.safetensors", repo_id=out_repo)
        return f"Done! Uploaded to {out_repo}"
    except Exception as e:
        # Top-level boundary: callers expect a status string, not an exception.
        return f"DARE Error: {e}"
    finally:
        # ignore_errors keeps a cleanup failure from replacing the status
        # string being returned above.
        shutil.rmtree(temp_dir, ignore_errors=True)
        gc.collect()