from tqdm import tqdm
import argparse
import requests
import merge
import os
import sys
import shutil
import yaml
from pathlib import Path
import gradio as gr


def parse_arguments():
    parser = argparse.ArgumentParser(description="Merge HuggingFace models")
    parser.add_argument('repo_list', type=str, help='File listing the repositories to merge; supports a mergekit-style YAML or a plain text list')
    parser.add_argument('output_dir', type=str, help='Directory for the merged models')
    parser.add_argument('-base_model', type=str, default='staging/base_model', help='Base model directory')
    parser.add_argument('-staging_model', type=str, default='staging/merge_model', help='Staging model directory')
    parser.add_argument('-p', type=float, default=0.5, help='Dropout probability')
    parser.add_argument('-lambda', dest='lambda_val', type=float, default=1.0, help='Scaling factor for the weight delta')
    parser.add_argument('--dry', action='store_true', help='Run in dry mode without making any changes')
    return parser.parse_args()


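# Illustrative only: repo_list_generator() below accepts either a plain text file
# (one repository per line) or a mergekit-style YAML. A minimal YAML the parser
# would understand looks like this (model names are hypothetical placeholders):
#
#   models:
#     - model: author/base-model
#       parameters:
#         weight: 0.5      # becomes p
#         density: 0.53    # lambda is derived as 1 / density
#     - model: author/other-model
#
# A missing 'weight' falls back to the -p default; a missing 'density' falls back
# to the -lambda value, with lambda taken as 1 / density either way.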
def repo_list_generator(file_path, default_p, default_lambda_val):
    _, file_extension = os.path.splitext(file_path)

    if file_extension.lower() in ('.yaml', '.yml'):
        with open(file_path, 'r', encoding='utf-8') as file:
            data = yaml.safe_load(file)
        for model_info in data['models']:
            model_name = model_info['model']
            p = model_info.get('parameters', {}).get('weight', default_p)
            # lambda is derived as the reciprocal of the mergekit-style 'density'
            lambda_val = 1 / model_info.get('parameters', {}).get('density', default_lambda_val)
            yield model_name, p, lambda_val
    else:
        with open(file_path, "r", encoding='utf-8') as file:
            repos_to_process = file.readlines()
        for repo in repos_to_process:
            repo = repo.strip()
            if repo:  # ignore blank lines
                yield repo, default_p, default_lambda_val


def reset_directories(directories, dry_run):
    for directory in directories:
        if os.path.exists(directory):
            if dry_run:
                print(f"[DRY RUN] Would delete directory {directory}")
            else:
                shutil.rmtree(directory)
                print(f"Directory {directory} deleted successfully.")


def do_merge(tensor_map, staging_path, p, lambda_val, dry_run=False):
    if dry_run:
        print(f"[DRY RUN] Would merge with {staging_path}")
    else:
        try:
            print(f"Merge operation for {staging_path}")
            tensor_map = merge.merge_folder(tensor_map, staging_path, p, lambda_val)
            print("Merge operation completed successfully.")
        except Exception as e:
            print(f"Error during merge operation: {e}")
    return tensor_map


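# The merge math itself lives in the local merge module (merge_folder / merge_files /
# merge_folder_diffusers). As a rough, illustrative sketch only -- not the actual
# implementation -- the p/lambda arguments suggest a DARE-style per-tensor update:
# drop a random fraction p of the delta between the incoming tensor and the base
# tensor, then add the surviving delta back scaled by lambda.
def _sketch_merge_tensor(base, other, p, lambda_val):
    """Hypothetical helper for illustration; not used by this script."""
    import torch
    delta = other - base                                           # weight delta
    keep_mask = torch.bernoulli(torch.full_like(delta, 1.0 - p))   # keep each element with prob 1 - p
    return base + lambda_val * keep_mask * delta                   # rescale the kept delta by lambda

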
def do_merge_files(base_path, staging_path, output_path, p, lambda_val, dry_run=False):
    tensor_map = None  # stays None on dry runs or errors so the return below is always defined
    if dry_run:
        print(f"[DRY RUN] Would merge with {staging_path}")
    else:
        try:
            print(f"Merge operation for {staging_path}")
            tensor_map = merge.merge_files(base_path, staging_path, output_path, p, lambda_val)
            print("Merge operation completed successfully.")
        except Exception as e:
            print(f"Error during merge operation: {e}")
    return tensor_map


def do_merge_diffusers(tensor_map, staging_path, p, lambda_val, skip_dirs, dry_run=False):
    if dry_run:
        print(f"[DRY RUN] Would merge with {staging_path}")
    else:
        try:
            print(f"Merge operation for {staging_path}")
            tensor_map = merge.merge_folder_diffusers(tensor_map, staging_path, p, lambda_val, skip_dirs)
            print("Merge operation completed successfully.")
        except Exception as e:
            print(f"Error during merge operation: {e}")
    return tensor_map


def download_repo(repo_name, path, dry_run=False):
    from huggingface_hub import snapshot_download
    if dry_run:
        print(f"[DRY RUN] Would download repository {repo_name} to {path}")
    else:
        print(f"Cloning repository {repo_name}.")
        try:
            snapshot_download(repo_id=repo_name, local_dir=path)
        except Exception as e:
            print(e)
            return
        print(f"Repository {repo_name} cloned successfully.")


def download_thing(directory, url, progress=gr.Progress(track_tqdm=True)):
    civitai_api_key = os.environ.get("CIVITAI_API_KEY")
    url = url.strip()
    if "drive.google.com" in url:
        original_dir = os.getcwd()
        os.chdir(directory)
        os.system(f"gdown --fuzzy {url}")
        os.chdir(original_dir)
    elif "huggingface.co" in url:
        url = url.replace("?download=true", "")
        if "/blob/" in url:
            url = url.replace("/blob/", "/resolve/")
            os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
        else:
            os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
    elif "civitai.com" in url:
        if "?" in url:
            url = url.split("?")[0]
        if civitai_api_key:
            url = url + f"?token={civitai_api_key}"
            os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
        else:
            print("You need an API key to download Civitai models.")
    else:
        os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")


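# Note on download_thing(): it shells out to external CLI tools rather than using a
# Python HTTP client -- gdown for Google Drive links and aria2c for everything else --
# so both must be available on PATH. Civitai downloads additionally require the
# CIVITAI_API_KEY environment variable, e.g. (shell example, placeholder value):
#
#   export CIVITAI_API_KEY=xxxxxxxx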
def get_local_model_list(dir_path):
    model_list = []
    valid_extensions = ('.safetensors',)  # one-element tuple; a bare string would match any substring, including empty suffixes
    for file in Path(dir_path).glob("*"):
        if file.suffix in valid_extensions:
            file_path = str(Path(f"{dir_path}/{file.name}"))
            model_list.append(file_path)
    return model_list


def list_sub(a, b):
    return [e for e in a if e not in b]


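# get_download_file() below detects what a download produced by diffing the directory
# listing: get_local_model_list() is taken before and after download_thing(), and
# list_sub() returns the newly appeared file. This works because get_local_model_list()
# only reports .safetensors files, so unrelated temporary files do not show up in the diff.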
def get_download_file(temp_dir, url):
    new_file = None
    if "http" not in url and Path(url).exists():
        print(f"Use local file: {url}")
        new_file = url
    elif Path(f"{temp_dir}/{url.split('/')[-1]}").exists():
        print(f"File to download already exists: {url}")
        new_file = f"{temp_dir}/{url.split('/')[-1]}"
    else:
        print(f"Start downloading: {url}")
        before = get_local_model_list(temp_dir)
        try:
            download_thing(temp_dir, url.strip())
        except Exception:
            print(f"Download failed: {url}")
            return None
        after = get_local_model_list(temp_dir)
        new_file = list_sub(after, before)[0] if list_sub(after, before) else None
        if new_file is None:
            print(f"Download failed: {url}")
            return None
        print(f"Download completed: {url}")
    return new_file


def download_file(url, path, dry_run=False):
    if dry_run:
        print(f"[DRY RUN] Would download file {url} to {path}")
    else:
        print(f"Downloading file {url}.")
        try:
            path = get_download_file(path, url)
        except Exception as e:
            print(e)
            return None
        print(f"File {url} downloaded successfully.")
    return path


def is_repo_name(s):
    import re
    return re.fullmatch(r'^[^/,\s]+?/[^/,\s]+?$', s)


def should_create_symlink(repo_name):
    if os.path.exists(repo_name):
        return True, os.path.isfile(repo_name)
    return False, False


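# download_or_link_repo() below dispatches on what repo_name looks like:
#   - an existing local file       -> symlinked into `path`
#   - an existing local directory  -> symlinked as `path` itself
#   - an http(s) URL               -> fetched via download_file() / aria2c / gdown
#   - a "user/repo" identifier     -> fetched via huggingface_hub snapshot_download()
# Anything else is left alone and None is returned.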
def download_or_link_repo(repo_name, path, dry_run=False):
    symlink, is_file = should_create_symlink(repo_name)

    if symlink and is_file:
        os.makedirs(path, exist_ok=True)
        symlink_path = os.path.join(path, os.path.basename(repo_name))
        os.symlink(os.path.abspath(repo_name), symlink_path)  # absolute target so the link stays valid from any cwd
        return symlink_path  # return the linked file so the caller can detect a file-based merge
    elif symlink:
        # path may already exist as an empty staging directory; remove it so the symlink can be created
        if os.path.isdir(path) and not os.path.islink(path) and not os.listdir(path):
            os.rmdir(path)
        os.symlink(os.path.abspath(repo_name), path)
    elif "http" in repo_name:
        return download_file(repo_name, path, dry_run)
    elif is_repo_name(repo_name):
        download_repo(repo_name, path, dry_run)
    return None


def delete_repo(path, dry_run=False):
    if dry_run:
        print(f"[DRY RUN] Would delete repository at {path}")
    else:
        try:
            shutil.rmtree(path)
            print(f"Repository at {path} deleted successfully.")
        except Exception as e:
            print(f"Error deleting repository at {path}: {e}")


def get_max_vocab_size(repo_list):
    max_vocab_size = 0
    repo_with_max_vocab = None

    for repo in repo_list:
        repo_name = repo[0].strip()
        url = f"https://huggingface.co/{repo_name}/raw/main/config.json"

        try:
            response = requests.get(url)
            response.raise_for_status()
            config = response.json()
            vocab_size = config.get("vocab_size", 0)

            if vocab_size > max_vocab_size:
                max_vocab_size = vocab_size
                repo_with_max_vocab = repo_name

        except requests.RequestException as e:
            print(f"Error fetching data from {url}: {e}")

    return max_vocab_size, repo_with_max_vocab


def download_json_files(repo_name, file_paths, output_dir):
    base_url = f"https://huggingface.co/{repo_name}/raw/main/"

    for file_path in file_paths:
        url = base_url + file_path
        response = requests.get(url)
        if response.status_code == 200:
            with open(os.path.join(output_dir, os.path.basename(file_path)), 'wb') as file:
                file.write(response.content)
        else:
            print(f"Failed to download {file_path}")


def get_merged_path(filename, output_dir):
    from datetime import datetime, timezone, timedelta
    dt_now = datetime.now(timezone(timedelta(hours=9)))  # timestamp in UTC+9 (JST)
    basename = dt_now.strftime('Merged_%Y%m%d_%H%M')
    ext = Path(filename).suffix
    return str(Path(output_dir, basename + ext)), str(Path(output_dir, basename + ".yaml"))


def repo_list_to_yaml(repo_list_path, repo_list, output_yaml_path):
    if Path(repo_list_path).suffix.lower() in (".yaml", ".yml"):
        shutil.copy(repo_list_path, output_yaml_path)
    else:
        repos = list(repo_list)
        # Emit 'models' as a list of entries so repo_list_generator() can read this file back in.
        yaml_dict = {'models': []}
        for repo in repos:
            model, weight, density = repo
            model_info = {
                'model': str(model),
                'parameters': {
                    'weight': float(weight),
                    'density': float(density),
                },
            }
            yaml_dict['models'].append(model_info)
        with open(output_yaml_path, mode='w', encoding='utf-8') as file:
            yaml.dump(yaml_dict, file, default_flow_style=False, allow_unicode=True)


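# process_repos() is the top-level pipeline. Roughly:
#   1. refuse to overwrite an existing output_dir, then reset the staging directories
#   2. read the repo list; the first entry becomes the base model
#   3. pick a merge mode from what the base entry resolves to:
#        a single .safetensors/.sft file        -> "Files"
#        a folder containing model_index.json   -> "Diffusers"
#        anything else                          -> "Default" (transformers layout)
#   4. download each remaining entry into the staging directory and merge it in
#   5. copy non-tensor files, save the merged tensors, and clean up staging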
def process_repos(output_dir, base_model, staging_model, repo_list_file, p, lambda_val, skip_dirs, dry_run=False, progress=gr.Progress(track_tqdm=True)):
    repo_type = "Default"

    if os.path.exists(output_dir):
        sys.exit(f"Output directory '{output_dir}' already exists. Exiting to prevent data loss.")

    reset_directories([base_model, staging_model], dry_run)

    os.makedirs(base_model, exist_ok=True)
    os.makedirs(staging_model, exist_ok=True)

    repo_list_gen = repo_list_generator(repo_list_file, p, lambda_val)
    repos_to_process = list(repo_list_gen)

    path = download_or_link_repo(repos_to_process[0][0].strip(), base_model, dry_run)
    if path is not None and (".safetensors" in path or ".sft" in path):
        repo_type = "Files"
    elif Path(base_model, "model_index.json").exists():
        repo_type = "Diffusers"

    if repo_type == "Files":
        os.makedirs(output_dir, exist_ok=True)
        output_file_path, output_yaml_path = get_merged_path(path, output_dir)
        # pass the materialized list: repo_list_gen is already exhausted by list() above
        repo_list_to_yaml(repo_list_file, repos_to_process, output_yaml_path)
        for i, repo in enumerate(tqdm(repos_to_process[1:], desc='Merging Files')):
            repo_name = repo[0].strip()
            repo_p = repo[1]
            repo_lambda = repo[2]
            delete_repo(staging_model, dry_run)
            staging_path = download_or_link_repo(repo_name, staging_model, dry_run)
            do_merge_files(path, staging_path, output_file_path, repo_p, repo_lambda, dry_run)
        reset_directories([base_model, staging_model], dry_run)
        return output_file_path, output_yaml_path
    elif repo_type == "Diffusers":
        merge.copy_dirs(base_model, output_dir)
        tensor_map = merge.map_tensors_to_files_diffusers(base_model, skip_dirs)

        for i, repo in enumerate(tqdm(repos_to_process[1:], desc='Merging Repos')):
            repo_name = repo[0].strip()
            repo_p = repo[1]
            repo_lambda = repo[2]
            delete_repo(staging_model, dry_run)
            download_or_link_repo(repo_name, staging_model, dry_run)
            tensor_map = do_merge_diffusers(tensor_map, staging_model, repo_p, repo_lambda, skip_dirs, dry_run)

        os.makedirs(output_dir, exist_ok=True)
        merge.copy_skipped_dirs(base_model, output_dir, skip_dirs)
        merge.copy_nontensor_files(base_model, output_dir)
        merge.save_tensor_map(tensor_map, output_dir)

        reset_directories([base_model, staging_model], dry_run)
        return None, None
    elif repo_type == "Default":
        merge.copy_dirs(base_model, output_dir)
        tensor_map = merge.map_tensors_to_files(base_model)

        for i, repo in enumerate(tqdm(repos_to_process[1:], desc='Merging Repos')):
            repo_name = repo[0].strip()
            repo_p = repo[1]
            repo_lambda = repo[2]
            delete_repo(staging_model, dry_run)
            download_or_link_repo(repo_name, staging_model, dry_run)
            tensor_map = do_merge(tensor_map, staging_model, repo_p, repo_lambda, dry_run)

        os.makedirs(output_dir, exist_ok=True)
        merge.copy_nontensor_files(base_model, output_dir)

        # pull config/tokenizer files from the repo with the largest vocab_size
        if os.path.exists(os.path.join(output_dir, 'config.json')):
            max_vocab_size, repo_name = get_max_vocab_size(repos_to_process)
            if max_vocab_size > 0:
                file_paths = ['config.json', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json']
                download_json_files(repo_name, file_paths, output_dir)

        reset_directories([base_model, staging_model], dry_run)
        merge.save_tensor_map(tensor_map, output_dir)
        return None, None


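# Example invocation (hypothetical file names; the script is assumed to be main.py):
#   python main.py repos.txt merged_model -p 0.5 -lambda 1.0 --dry
# where repos.txt lists one HuggingFace repository per line, or a mergekit-style YAML
# file (see the sketch near repo_list_generator) can be passed instead.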
if __name__ == "__main__":
    args = parse_arguments()
    skip_dirs = ['vae', 'text_encoder']
    process_repos(args.output_dir, args.base_model, args.staging_model, args.repo_list, args.p, args.lambda_val, skip_dirs, args.dry)