Spaces:
Sleeping
Sleeping
import gc
import os
import shutil

import numpy as np
import torch
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import load_file, save_file
from tqdm import tqdm
def download_file(repo_id, filename, token, local_dir):
    """Fetch a single file from a Hugging Face repo into *local_dir*.

    Thin convenience wrapper around ``hf_hub_download``; returns the
    local filesystem path of the downloaded file.
    """
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        token=token,
        local_dir=local_dir,
    )
def task_dare_custom(hf_token, base_repo, fine_tuned_repo, ratio, mask_rate, out_repo, private):
    """Merge a fine-tuned model into its base via DARE (Drop And REscale).

    Steps:
      1. Load base and fine-tuned safetensors checkpoints.
      2. delta = fine_tuned - base
      3. mask ~ Bernoulli(1 - mask_rate)  (randomly drop delta entries)
      4. rescale surviving entries by 1 / (1 - mask_rate)
      5. new = base + delta * mask * rescale * ratio

    Args:
        hf_token: Hugging Face API token used for download and upload.
        base_repo: Repo id of the base model.
        fine_tuned_repo: Repo id of the fine-tuned model.
        ratio: Global scaling applied to the rescaled delta.
        mask_rate: Fraction of delta entries dropped; must be in [0, 1).
        out_repo: Destination repo id for the merged checkpoint.
        private: Whether to create the destination repo as private.

    Returns:
        A human-readable status string (success or error message);
        errors are reported as strings rather than raised.
    """
    # Validate up front: mask_rate >= 1 previously produced an inf scale
    # factor (or ZeroDivisionError) only after the downloads completed.
    if not 0.0 <= mask_rate < 1.0:
        return "Error: mask_rate must be in [0, 1)."

    api = HfApi(token=hf_token)
    temp_dir = "./temp_dare"
    os.makedirs(temp_dir, exist_ok=True)
    try:
        # 1. Identify model files (naive: first .safetensors whose name contains "model").
        print("Locating files...")
        base_files = api.list_repo_files(base_repo)
        base_sf = next((f for f in base_files if f.endswith(".safetensors") and "model" in f), None)
        ft_files = api.list_repo_files(fine_tuned_repo)
        ft_sf = next((f for f in ft_files if f.endswith(".safetensors") and "model" in f), None)
        if not base_sf or not ft_sf:
            return "Error: Could not locate .safetensors in one of the repos."
        print(f"Downloading {base_sf}...")
        base_path = download_file(base_repo, base_sf, hf_token, temp_dir)
        print(f"Downloading {ft_sf}...")
        ft_path = download_file(fine_tuned_repo, ft_sf, hf_token, temp_dir)
        # 2. Process
        print("Loading tensors...")
        base_sd = load_file(base_path, device="cpu")
        ft_sd = load_file(ft_path, device="cpu")
        merged_sd = {}
        # Only keys present in BOTH checkpoints are merged; others are
        # silently dropped from the output (single-shard assumption).
        keys = set(base_sd.keys()).intersection(ft_sd.keys())
        scale_factor = 1.0 / (1.0 - mask_rate)
        print("Applying DARE...")
        for k in tqdm(keys):
            b_tensor = base_sd[k]
            f_tensor = ft_sd[k]
            if b_tensor.shape != f_tensor.shape:
                merged_sd[k] = f_tensor  # Fallback: incompatible shapes, keep FT weights.
                continue
            # Skip 1D tensors (typically LayerNorm/bias) and non-float tensors.
            if len(b_tensor.shape) < 2 or not b_tensor.is_floating_point():
                merged_sd[k] = f_tensor  # Keep FT version unmodified.
                continue
            # Delta between fine-tuned and base weights.
            delta = f_tensor - b_tensor
            # Bernoulli keep-mask: each entry survives with probability 1 - mask_rate.
            mask = torch.bernoulli(torch.full_like(delta, 1.0 - mask_rate))
            # DARE formula: base + delta * mask * rescale * ratio.
            final = b_tensor + (delta * mask * scale_factor * ratio)
            merged_sd[k] = final.to(torch.bfloat16)  # Enforce BF16 for save.
        # 3. Save
        out_path = os.path.join(temp_dir, "dare_merged.safetensors")
        save_file(merged_sd, out_path)
        # 4. Upload
        print("Uploading...")
        api.create_repo(repo_id=out_repo, private=private, exist_ok=True)
        api.upload_file(path_or_fileobj=out_path, path_in_repo="model.safetensors", repo_id=out_repo)
        return f"Done! Uploaded to {out_repo}"
    except Exception as e:
        return f"DARE Error: {e}"
    finally:
        # Bug fix: `shutil` was referenced here without ever being imported,
        # so every call previously raised NameError during cleanup (masking
        # the real return value). Requires the top-of-file `import shutil`.
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
        gc.collect()