# Hugging Face Space app — SOONmerge LoRA Toolkit
# (scraped "Spaces: Running" status banner removed; it was not part of the code)
| import gradio as gr | |
| import torch | |
| import os | |
| import gc | |
| import shutil | |
| import requests | |
| import json | |
| import struct | |
| import numpy as np | |
| import re | |
| from pathlib import Path | |
| from typing import Dict, Any, Optional, List | |
| from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login | |
| from safetensors.torch import load_file, save_file | |
| from tqdm import tqdm | |
| # --- Memory Efficient Safetensors --- | |
class MemoryEfficientSafeOpen:
    """
    Minimal safetensors reader that parses the header with the stdlib and
    streams individual tensors via seek/read, keeping RAM usage low
    (no mmap, no whole-file load).

    Use as a context manager so the file handle is closed:
        with MemoryEfficientSafeOpen(path) as f:
            t = f.get_tensor(f.keys()[0])
    """

    # safetensors dtype tag -> torch dtype (F64 added; extend here as needed)
    _DTYPE_MAP = {
        "F64": torch.float64, "F32": torch.float32, "F16": torch.float16,
        "BF16": torch.bfloat16,
        "I64": torch.int64, "I32": torch.int32, "I16": torch.int16,
        "I8": torch.int8, "U8": torch.uint8, "BOOL": torch.bool,
    }

    def __init__(self, filename):
        self.filename = filename
        self.file = open(filename, "rb")
        self.header, self.header_size = self._read_header()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self) -> list[str]:
        """Names of the tensors stored in the file (metadata entry excluded)."""
        return [k for k in self.header.keys() if k != "__metadata__"]

    def metadata(self) -> Dict[str, str]:
        """The optional '__metadata__' dict from the header ({} when absent)."""
        return self.header.get("__metadata__", {})

    def get_tensor(self, key):
        """Read a single tensor by name without touching the rest of the file."""
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")
        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]
        # Offsets are relative to the end of the header; the header itself is
        # preceded by an 8-byte little-endian length field.
        self.file.seek(self.header_size + 8 + offset_start)
        tensor_bytes = self.file.read(offset_end - offset_start)
        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        """Parse the JSON header; returns (header_dict, header_byte_size)."""
        header_size = struct.unpack("<Q", self.file.read(8))[0]
        header_json = self.file.read(header_size).decode("utf-8")
        return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        """Turn a raw byte span into a tensor of the recorded dtype and shape."""
        dtype = self._DTYPE_MAP[metadata["dtype"]]
        shape = metadata["shape"]
        # Fix: zero-element tensors have an empty byte span, which
        # torch.frombuffer rejects — build them directly.
        if not tensor_bytes:
            return torch.empty(shape, dtype=dtype)
        # Fix: copy into a writable bytearray; frombuffer on immutable `bytes`
        # emits a UserWarning and yields a tensor backed by read-only memory.
        buffer = bytearray(tensor_bytes)
        return torch.frombuffer(buffer, dtype=torch.uint8).view(dtype).reshape(shape)
| # --- Constants & Setup --- | |
try:
    # Prefer the ephemeral /tmp mount (HF Spaces); fall back to a local
    # directory when /tmp is unavailable or not writable.
    TempDir = Path("/tmp/temp_tool")
    os.makedirs(TempDir, exist_ok=True)
except OSError:  # was a bare `except:`, which also swallowed SystemExit etc.
    TempDir = Path("./temp_tool")
    os.makedirs(TempDir, exist_ok=True)

# Shared Hub client used by all upload/create_repo calls below.
api = HfApi()
def cleanup_temp():
    """Delete everything under the scratch dir, recreate it empty, and run the GC."""
    if TempDir.exists():
        shutil.rmtree(TempDir)
    TempDir.mkdir(parents=True, exist_ok=True)
    gc.collect()
def download_file(input_path, token, filename=None):
    """
    Fetch a single .safetensors file from a direct URL or a Hub repo id.

    Args:
        input_path: http(s) URL or a "user/repo" Hub id.
        token: HF token for Hub access (URL downloads are unauthenticated here).
        filename: target filename; when None and input is a repo, the first
            .safetensors file in the repo is used (default adapter_model.safetensors).

    Returns:
        Path to the downloaded file inside TempDir.

    Raises:
        ValueError: when the download fails.
    """
    local_path = TempDir / (filename if filename else "model.safetensors")
    if input_path.startswith("http"):
        print(f"Downloading {input_path} from URL...")
        try:
            response = requests.get(input_path, stream=True, timeout=30)
            response.raise_for_status()
            with open(local_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        except Exception as e:
            raise ValueError(f"Download failed: {e}") from e
    else:
        print(f"Downloading from Hub repo {input_path}...")
        if not filename:
            # Pick the first .safetensors file in the repo when none was given.
            try:
                files = list_repo_files(repo_id=input_path, token=token)
                safetensors = [f for f in files if f.endswith(".safetensors")]
                filename = safetensors[0] if safetensors else "adapter_model.safetensors"
            except Exception:  # was a bare `except:`
                filename = "adapter_model.safetensors"
        try:
            hf_hub_download(repo_id=input_path, filename=filename, token=token,
                            local_dir=TempDir, local_dir_use_symlinks=False)
            downloaded = TempDir / filename
            if not downloaded.exists():
                # The hub may nest the file under repo subfolders; locate it.
                found = list(TempDir.rglob(os.path.basename(filename)))
                if found:
                    shutil.move(found[0], local_path)
            elif downloaded != local_path:
                # Fix: the original returned `local_path` even when the file
                # actually landed at TempDir/<filename>; move it into place so
                # the returned path always exists.
                if local_path.exists():
                    os.remove(local_path)
                shutil.move(downloaded, local_path)
        except Exception as e:
            raise ValueError(f"Hub download failed: {e}") from e
    return local_path
def get_key_stem(key):
    """
    Normalize a LoRA/base state-dict key to a comparable "stem": drops
    .weight/.bias and LoRA-specific suffixes, then peels known framework
    prefixes (diffusers, kohya, peft) repeatedly until none remain.
    """
    for suffix in (".weight", ".bias", ".lora_down", ".lora_up",
                   ".lora_A", ".lora_B", ".alpha"):
        key = key.replace(suffix, "")
    prefixes = (
        "model.diffusion_model.", "diffusion_model.", "model.",
        "transformer.", "text_encoder.", "lora_unet_", "lora_te_",
        "base_model.model.",
    )
    stripped = True
    while stripped:
        stripped = False
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                stripped = True
    return key
| # ================================================================================= | |
| # TAB 1: MERGE & RESHARD (Fixes Folder Structure & Aux Files) | |
| # ================================================================================= | |
def load_lora_to_memory(lora_path, precision_dtype=torch.bfloat16):
    """
    Load a LoRA safetensors file and group its tensors by normalized key stem.

    Returns a dict: stem -> {"down": A, "up": B, "rank": r, "alpha": a},
    where alpha falls back to the rank when the file stores no alpha tensor.
    """
    print(f"Loading LoRA from {lora_path}...")
    state_dict = load_file(lora_path, device="cpu")
    pairs = {}
    alphas = {}
    for key, value in state_dict.items():
        stem = get_key_stem(key)
        if "alpha" in key:
            alphas[stem] = value.item() if isinstance(value, torch.Tensor) else value
            continue
        entry = pairs.setdefault(stem, {})
        if "lora_down" in key or "lora_A" in key:
            entry["down"] = value.to(dtype=precision_dtype)
            entry["rank"] = value.shape[0]
        elif "lora_up" in key or "lora_B" in key:
            entry["up"] = value.to(dtype=precision_dtype)
    # Default alpha to the rank (the common trainer convention) when absent.
    for stem, entry in pairs.items():
        entry["alpha"] = alphas.get(stem, float(entry.get("rank", 1.0)))
    return pairs
class ShardBuffer:
    """
    Accumulating writer that re-shards a stream of tensors into safetensors
    files of at most `max_size_gb` each, uploading every finished shard to
    `output_repo` (optionally under `subfolder`) and deleting it locally.

    `index_map` and `total_size` collect the data needed for the final
    *.safetensors.index.json.
    """

    # torch dtype -> safetensors dtype tag. bfloat16 is handled separately
    # because numpy has no bf16 type (its bits are exported via an int16 view).
    _DTYPE_TAGS = {
        torch.float16: "F16",
        torch.float32: "F32",
        torch.float64: "F64",
        torch.int64: "I64",
        torch.int32: "I32",
        torch.int16: "I16",
        torch.int8: "I8",
        torch.uint8: "U8",
        torch.bool: "BOOL",
    }

    def __init__(self, max_size_gb, output_dir, output_repo, subfolder, hf_token, filename_prefix="model"):
        self.max_bytes = int(max_size_gb * 1024**3)
        self.output_dir = output_dir
        self.output_repo = output_repo
        self.subfolder = subfolder
        self.hf_token = hf_token
        self.filename_prefix = filename_prefix
        self.buffer = []          # pending tensors: dicts of key/data/dtype/shape
        self.current_bytes = 0    # bytes pending in the current shard
        self.shard_count = 0
        self.index_map = {}       # tensor key -> shard filename (for index.json)
        self.total_size = 0       # accumulated size of ALL tensors, for index.json

    def add_tensor(self, key, tensor):
        """Queue one tensor; flush automatically when the shard limit is hit."""
        # Fix: .numpy() / .view() require contiguous memory, and tensors
        # arriving from a merge may be non-contiguous views.
        tensor = tensor.detach().contiguous()
        if tensor.dtype == torch.bfloat16:
            # numpy has no bfloat16: export the raw bits through an int16 view.
            raw_bytes = tensor.view(torch.int16).numpy().tobytes()
            dtype_str = "BF16"
        else:
            # Fix: previously every non-fp16 dtype was labeled "F32", which
            # corrupted integer/bool tensors. Map dtypes properly and fall
            # back to an explicit float32 cast for anything unknown.
            if tensor.dtype not in self._DTYPE_TAGS:
                tensor = tensor.to(torch.float32)
            raw_bytes = tensor.numpy().tobytes()
            dtype_str = self._DTYPE_TAGS[tensor.dtype]
        size = len(raw_bytes)
        self.buffer.append({
            "key": key,
            "data": raw_bytes,
            "dtype": dtype_str,
            "shape": list(tensor.shape),  # plain list for clean JSON output
        })
        self.current_bytes += size
        self.total_size += size
        if self.current_bytes >= self.max_bytes:
            self.flush()

    def flush(self):
        """Write pending tensors as one shard, upload it, and reset the buffer."""
        if not self.buffer:
            return
        self.shard_count += 1
        # Standard indexed-shard naming: prefix-0000X.safetensors.
        filename = f"{self.filename_prefix}-{self.shard_count:05d}.safetensors"
        # Fix: the repo path f-string had been garbled to a literal
        # "(unknown)"; the shard must be uploaded under its real filename.
        path_in_repo = f"{self.subfolder}/{filename}" if self.subfolder else filename
        print(f"Flushing {path_in_repo} ({self.current_bytes / 1024**3:.2f} GB)...")
        header = {"__metadata__": {"format": "pt"}}
        current_offset = 0
        for item in self.buffer:
            header[item["key"]] = {
                "dtype": item["dtype"],
                "shape": item["shape"],
                "data_offsets": [current_offset, current_offset + len(item["data"])],
            }
            current_offset += len(item["data"])
            self.index_map[item["key"]] = filename  # index.json uses bare filenames
        header_json = json.dumps(header).encode('utf-8')
        out_path = self.output_dir / filename
        with open(out_path, 'wb') as f:
            f.write(struct.pack('<Q', len(header_json)))
            f.write(header_json)
            for item in self.buffer:
                f.write(item["data"])
        print(f"Uploading {path_in_repo}...")
        api.upload_file(path_or_fileobj=out_path, path_in_repo=path_in_repo, repo_id=self.output_repo, token=self.hf_token)
        os.remove(out_path)
        self.buffer = []
        self.current_bytes = 0
        gc.collect()
def download_lora_smart(input_str, token):
    """
    Resolve a LoRA reference that may be a direct URL or a Hub repo id.

    Tries the URL first (with bearer auth when a token is given); on failure,
    or for non-URL input, falls back to downloading the most likely
    .safetensors file from the repo.

    Returns:
        Path to the local file (always TempDir / "adapter.safetensors").

    Raises:
        ValueError: if neither strategy yields a file.
    """
    local_path = TempDir / "adapter.safetensors"
    # 1. Direct URL (works for private files when a token is supplied).
    if input_str.startswith("http"):
        print(f"Downloading LoRA from URL: {input_str}")
        headers = {"Authorization": f"Bearer {token}"} if token else {}
        try:
            response = requests.get(input_str, stream=True, headers=headers, timeout=30)
            response.raise_for_status()
            with open(local_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            # Sanity check: a valid safetensors file starts with an 8-byte
            # header-length field. On failure we fall through to the repo path.
            with open(local_path, "rb") as f:
                if len(f.read(8)) == 8:
                    return local_path
        except Exception as e:
            print(f"URL download failed: {e}. Trying as Repo ID...")
    # 2. Repo id fallback (e.g. "AlekseyCalvin/MyLora").
    print(f"Attempting download from Hub Repo: {input_str}")
    try:
        # Prefer the conventional adapter filenames, otherwise the first
        # .safetensors file the repo lists.
        candidates = ["adapter_model.safetensors", "model.safetensors"]
        target_file = None
        try:
            files = list_repo_files(repo_id=input_str, token=token)
            safetensors = [f for f in files if f.endswith(".safetensors")]
            for c in candidates:
                if c in safetensors:
                    target_file = c
                    break
            if not target_file and safetensors:
                target_file = safetensors[0]
        except Exception:
            # Fix: was a bare `except:`. Listing can fail on gated/private
            # repos — fall back to the default adapter name.
            target_file = "adapter_model.safetensors"
        hf_hub_download(repo_id=input_str, filename=target_file, token=token, local_dir=TempDir, local_dir_use_symlinks=False)
        # Normalize to the generic local name expected by callers.
        downloaded = TempDir / target_file
        if downloaded != local_path:
            if local_path.exists():
                os.remove(local_path)
            shutil.move(downloaded, local_path)
        return local_path
    except Exception as e:
        raise ValueError(f"Failed to download LoRA from {input_str}. \nError: {e}")
def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, precision, shard_size, output_repo, structure_repo, private, progress=gr.Progress()):
    """
    Gradio handler: merge a LoRA into a (possibly sharded) base model and
    re-upload the result as freshly sharded safetensors plus an index.json.

    Args:
        hf_token: HF access token used for reading and writing.
        base_repo: repo id of the base model.
        base_subfolder: subfolder holding the weights (e.g. "transformer"); empty for root.
        lora_input: LoRA URL or repo id, resolved by download_lora_smart.
        scale: multiplier applied on top of the LoRA's alpha/rank scaling.
        precision: "bf16" | "fp16" | anything else -> float32.
        shard_size: maximum shard size in GB for the re-shard.
        output_repo: destination repo id (created if missing).
        structure_repo: optional repo whose non-weight files are cloned first.
        private: whether the output repo is created private.
        progress: Gradio progress reporter.

    Returns:
        Status string (success message or error description).
    """
    cleanup_temp()
    login(hf_token)
    # 1. Output setup: make sure the destination repo exists.
    try:
        api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
    except Exception as e: return f"Error creating repo: {e}"
    # Merged weights are written under the same subfolder they were read from.
    output_subfolder = base_subfolder if base_subfolder else ""
    # 2. Clone auxiliary structure (VAE, text encoder, configs) if requested.
    if structure_repo:
        print(f"Cloning structure from {structure_repo}...")
        # Skip the folder we are about to overwrite with merged weights.
        ignore = output_subfolder if output_subfolder else None
        # Root merge mode (LLM-style layout) implies skipping root weights.
        is_root_merge = not bool(output_subfolder)
        # NOTE(review): streaming_copy_structure is not defined anywhere in
        # this file as shown — confirm it exists elsewhere, otherwise this
        # call raises NameError whenever structure_repo is set.
        streaming_copy_structure(hf_token, structure_repo, output_repo, ignore_prefix=ignore, is_root_merge=is_root_merge)
    # 3. Download the base model shards (filtered to the chosen subfolder).
    progress(0.1, desc="Downloading Base Model...")
    try:
        files = list_repo_files(repo_id=base_repo, token=hf_token)
    except Exception as e: return f"Error accessing base repo: {e}"
    input_shards = []
    for f in files:
        if f.endswith(".safetensors"):
            # Filter by subfolder if specified
            if output_subfolder and not f.startswith(output_subfolder): continue
            local_path = TempDir / "input_shards" / os.path.basename(f)
            os.makedirs(local_path.parent, exist_ok=True)
            hf_hub_download(repo_id=base_repo, filename=f, token=hf_token, local_dir=local_path.parent, local_dir_use_symlinks=False)
            # hf_hub_download may nest the file under the repo's folder
            # structure; locate it wherever it landed.
            found = list(local_path.parent.rglob(os.path.basename(f)))
            if found: input_shards.append(found[0])
    if not input_shards: return "No base safetensors found in specified location."
    input_shards.sort()
    # --- Naming convention: mirror diffusers vs. LLM-style shard names. ---
    # 1. Diffusers-specific subfolders force 'diffusion_pytorch_model'.
    if output_subfolder in ["transformer", "unet"]:
        filename_prefix = "diffusion_pytorch_model"
        index_filename = "diffusion_pytorch_model.safetensors.index.json"
    # 2. Otherwise adopt the input files' naming convention.
    elif "diffusion_pytorch_model" in os.path.basename(input_shards[0]):
        filename_prefix = "diffusion_pytorch_model"
        index_filename = "diffusion_pytorch_model.safetensors.index.json"
    # 3. Default to the LLM-style 'model' prefix.
    else:
        filename_prefix = "model"
        index_filename = "model.safetensors.index.json"
    print(f"Naming scheme: {filename_prefix} (Index: {index_filename})")
    # 4. Load the LoRA fully into RAM, cast to the working precision.
    dtype = torch.bfloat16 if precision == "bf16" else torch.float16 if precision == "fp16" else torch.float32
    try:
        progress(0.15, desc="Downloading LoRA...")
        lora_path = download_lora_smart(lora_input, hf_token)
        lora_pairs = load_lora_to_memory(lora_path, precision_dtype=dtype)
    except Exception as e: return f"Error loading LoRA: {e}"
    # 5. Stream-process each base shard: apply LoRA deltas key by key and
    # feed the results into the re-sharding buffer.
    buffer = ShardBuffer(shard_size, TempDir, output_repo, output_subfolder, hf_token, filename_prefix=filename_prefix)
    for i, shard_file in enumerate(input_shards):
        progress(0.2 + (0.7 * i / len(input_shards)), desc=f"Processing {os.path.basename(shard_file)}")
        with MemoryEfficientSafeOpen(shard_file) as f:
            keys = f.keys()
            for k in keys:
                v = f.get_tensor(k)
                base_stem = get_key_stem(k)
                lora_keys = set(lora_pairs.keys())
                match = None
                if base_stem in lora_keys: match = lora_pairs[base_stem]
                # QKV heuristics: some LoRAs (Z-Image/Flux style) train a fused
                # qkv projection; map the base's split q/k/v keys onto it.
                if not match:
                    if "to_q" in base_stem:
                        qkv_stem = base_stem.replace("to_q", "qkv")
                        if qkv_stem in lora_keys: match = lora_pairs[qkv_stem]
                    elif "to_k" in base_stem:
                        qkv_stem = base_stem.replace("to_k", "qkv")
                        if qkv_stem in lora_keys: match = lora_pairs[qkv_stem]
                    elif "to_v" in base_stem:
                        qkv_stem = base_stem.replace("to_v", "qkv")
                        if qkv_stem in lora_keys: match = lora_pairs[qkv_stem]
                if match and "down" in match and "up" in match:
                    down = match["down"]
                    up = match["up"]
                    # Effective LoRA scaling: user scale * alpha / rank.
                    scaling = scale * (match["alpha"] / match["rank"])
                    # 1x1-conv base weights: lift 2-D LoRA matrices to 4-D.
                    if len(v.shape) == 4 and len(down.shape) == 2:
                        down = down.unsqueeze(-1).unsqueeze(-1)
                        up = up.unsqueeze(-1).unsqueeze(-1)
                    try:
                        if len(up.shape) == 4:
                            delta = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], 1, 1)
                        else:
                            delta = up @ down
                    except:
                        # Fallback for transposed LoRA layouts.
                        delta = up.T @ down
                    delta = delta * scaling
                    valid_delta = True
                    if delta.shape == v.shape: pass
                    elif delta.shape[0] == v.shape[0] * 3:
                        # Fused-qkv delta is 3x taller: slice out this key's third.
                        chunk = v.shape[0]
                        if "to_q" in k: delta = delta[0:chunk, ...]
                        elif "to_k" in k: delta = delta[chunk:2*chunk, ...]
                        elif "to_v" in k: delta = delta[2*chunk:, ...]
                        else: valid_delta = False
                    elif delta.numel() == v.numel():
                        delta = delta.reshape(v.shape)
                    else: valid_delta = False
                    if valid_delta:
                        v = v.to(dtype)
                        delta = delta.to(dtype)
                        v.add_(delta)
                        del delta
                # Ensure the output dtype even for untouched tensors.
                if v.dtype != dtype: v = v.to(dtype)
                buffer.add_tensor(k, v)
                del v
        # Delete the consumed input shard immediately to cap disk usage.
        os.remove(shard_file)
        gc.collect()
    buffer.flush()
    # 6. Write and upload the index (weight_map + accumulated total_size).
    print(f"Uploading Index: {index_filename} (Total Size: {buffer.total_size})")
    index_data = {"metadata": {"total_size": buffer.total_size}, "weight_map": buffer.index_map}
    with open(TempDir / index_filename, "w") as f:
        json.dump(index_data, f, indent=4)
    path_in_repo = f"{output_subfolder}/{index_filename}" if output_subfolder else index_filename
    api.upload_file(path_or_fileobj=TempDir / index_filename, path_in_repo=path_in_repo, repo_id=output_repo, token=hf_token)
    cleanup_temp()
    return f"Done! Merged into {buffer.shard_count} shards at {output_repo}"
| # ================================================================================= | |
| # TAB 2: EXTRACT LORA | |
| # ================================================================================= | |
def extract_lora_layer_by_layer(model_org, model_tuned, rank, clamp):
    """
    Extract a LoRA adapter as the low-rank SVD of (tuned - original),
    processed one layer at a time to keep memory bounded.

    Args:
        model_org: path to the original model .safetensors.
        model_tuned: path to the finetuned model .safetensors.
        rank: maximum LoRA rank to keep per layer.
        clamp: quantile in (0, 1] used to clip extreme singular-vector values.

    Returns:
        Path (str) of the extracted LoRA file written into TempDir.
    """
    lora_sd = {}
    print("Calculating diffs...")
    # Fix: use context managers so both file handles are closed (they leaked
    # before), and build the tuned key set once instead of scanning a fresh
    # list for every key.
    with MemoryEfficientSafeOpen(model_org) as org, MemoryEfficientSafeOpen(model_tuned) as tuned:
        tuned_keys = set(tuned.keys())
        for key in tqdm(org.keys()):
            if key not in tuned_keys:
                continue
            mat_org = org.get_tensor(key).float()
            diff = tuned.get_tensor(key).float() - mat_org
            # Fix: 1-D tensors (biases, norms) crashed the shape unpack below;
            # they cannot be low-rank factorized, so skip them.
            if diff.ndim < 2:
                continue
            if torch.max(torch.abs(diff)) < 1e-4:
                continue  # layer effectively unchanged
            out_dim, in_dim = diff.shape[:2]
            r = min(rank, in_dim, out_dim)
            is_conv = diff.ndim == 4
            if is_conv:
                diff = diff.flatten(start_dim=1)
            try:
                U, S, Vh = torch.linalg.svd(diff, full_matrices=False)
                U, S, Vh = U[:, :r], S[:r], Vh[:r, :]
                U = U @ torch.diag(S)
                # Clip outliers at the requested quantile of all values.
                dist = torch.cat([U.flatten(), Vh.flatten()])
                hi_val = torch.quantile(dist, clamp)
                U = U.clamp(-hi_val, hi_val)
                Vh = Vh.clamp(-hi_val, hi_val)
                if is_conv:
                    U = U.reshape(out_dim, r, 1, 1)
                    Vh = Vh.reshape(r, in_dim, mat_org.shape[2], mat_org.shape[3])
                else:
                    U = U.reshape(out_dim, r)
                    Vh = Vh.reshape(r, in_dim)
                stem = key.replace(".weight", "")
                lora_sd[f"{stem}.lora_up.weight"] = U
                lora_sd[f"{stem}.lora_down.weight"] = Vh
                lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
            except Exception as e:
                # Fix: was a silent bare `except: pass`; SVD/quantile can fail
                # on degenerate layers — skip them, but say so.
                print(f"Skipping {key}: {e}")
    out = TempDir / "extracted.safetensors"
    save_file(lora_sd, out)
    return str(out)
def task_extract(hf_token, org, tun, rank, out):
    """Gradio handler: download two models, extract their LoRA diff, upload it."""
    cleanup_temp()
    login(hf_token)
    try:
        original_path = download_file(org, hf_token, filename="org.safetensors")
        tuned_path = download_file(tun, hf_token, filename="tun.safetensors")
        extracted = extract_lora_layer_by_layer(original_path, tuned_path, int(rank), 0.99)
        api.create_repo(repo_id=out, exist_ok=True, token=hf_token)
        api.upload_file(path_or_fileobj=extracted, path_in_repo="extracted.safetensors", repo_id=out, token=hf_token)
    except Exception as e:
        return f"Error: {e}"
    return "Done"
| # ================================================================================= | |
| # TAB 3: MERGE ADAPTERS (EMA) with Sigma Rel | |
| # ================================================================================= | |
| def sigma_rel_to_gamma(sigma_rel): | |
| t = sigma_rel**-2 | |
| coeffs = [1, 7, 16 - t, 12 - t] | |
| roots = np.roots(coeffs) | |
| gamma = roots[np.isreal(roots) & (roots.real >= 0)].real.max() | |
| return gamma | |
def task_merge_adapters(hf_token, lora_urls, beta, sigma_rel, out_repo):
    """
    Merge several LoRA adapters into one by exponential moving average.

    Args:
        hf_token: HF token for downloads and the upload.
        lora_urls: comma-separated URLs / repo ids, oldest first.
        beta: fixed EMA decay (retention of the running average).
        sigma_rel: when > 0, overrides beta with the post-hoc power-EMA
            schedule beta_t = (1 - 1/t)^(gamma + 1).
        out_repo: destination repo id.
    """
    cleanup_temp()
    login(hf_token)
    urls = [u.strip() for u in lora_urls.split(",") if u.strip()]
    paths = []
    try:
        for i, url in enumerate(urls):
            paths.append(download_file(url, hf_token, filename=f"a_{i}.safetensors"))
    except Exception as e:
        return f"Download Error: {e}"
    if not paths:
        return "No models found"
    base_sd = load_file(paths[0], device="cpu")
    for k in base_sd:
        if base_sd[k].dtype.is_floating_point:
            base_sd[k] = base_sd[k].float()
    gamma = sigma_rel_to_gamma(sigma_rel) if sigma_rel > 0 else None
    for i, path in enumerate(paths[1:]):
        print(f"Merging {path}")
        if gamma is not None:
            # Fix: base_sd already holds adapter #1, so merging paths[1:][i]
            # is EMA step t = i + 2. The previous `t = i + 1` made beta_1 = 0
            # and discarded the first adapter entirely on the first merge.
            t = i + 2
            current_beta = (1 - 1 / t) ** (gamma + 1)
        else:
            current_beta = beta
        curr = load_file(path, device="cpu")
        for k in base_sd:
            # Keep alpha tensors untouched; average everything shared.
            if k in curr and "alpha" not in k:
                base_sd[k] = base_sd[k] * current_beta + curr[k].float() * (1 - current_beta)
    out = TempDir / "merged_adapters.safetensors"
    save_file(base_sd, out)
    api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=out_repo, token=hf_token)
    return "Done"
| # ================================================================================= | |
| # TAB 4: RESIZE | |
| # ================================================================================= | |
def index_sv_ratio(S, target):
    """
    Number of singular values to keep so that max_sv / min_kept_sv <= target.
    Result is clamped to [1, len(S) - 1], matching the dynamic-rank heuristic.
    """
    threshold = S[0] / target
    keep = int((S > threshold).sum().item())
    return max(1, min(keep, len(S) - 1))
def task_resize(hf_token, lora_input, new_rank, dynamic_method, dynamic_param, out_repo):
    """
    Resize a LoRA to a lower rank via SVD of the merged delta, preserving the
    adapter's effective strength (its alpha/rank scaling).

    Args:
        hf_token: HF token.
        lora_input: LoRA URL or repo id.
        new_rank: fixed target rank (used unless dynamic_method applies).
        dynamic_method: "sv_ratio" to pick the rank from the singular-value
            spectrum, anything else for the fixed rank.
        dynamic_param: target max/min singular-value ratio for "sv_ratio".
        out_repo: destination repo id.
    """
    cleanup_temp()
    login(hf_token)
    try:
        path = download_file(lora_input, hf_token)
    except Exception as e:
        return f"Error: {e}"
    state = load_file(path, device="cpu")
    new_state = {}
    groups = {}
    # Group down/up/alpha tensors by their shared module name.
    for k in state:
        simple = k.split(".lora_")[0]
        if simple not in groups:
            groups[simple] = {}
        if "lora_down" in k or "lora_A" in k:
            groups[simple]["down"] = state[k]
        if "lora_up" in k or "lora_B" in k:
            groups[simple]["up"] = state[k]
        if "alpha" in k:
            groups[simple]["alpha"] = state[k]
    for stem, g in tqdm(groups.items()):
        if "down" not in g or "up" not in g:
            continue
        down, up = g["down"].float(), g["up"].float()
        rank = down.shape[0]
        # Fix: fold the original alpha/rank scaling into the delta before the
        # SVD. The output's alpha is set to its own rank (scale 1.0), so
        # without this step any adapter with alpha != rank changed strength.
        alpha = float(g["alpha"]) if "alpha" in g else float(rank)
        strength = alpha / rank
        if len(down.shape) == 4:
            # 1x1-conv LoRA: collapse to 2-D for the SVD.
            merged = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3])
            flat = merged.flatten(1) * strength
        else:
            flat = (up @ down) * strength
        U, S, Vh = torch.linalg.svd(flat, full_matrices=False)
        target_rank = int(new_rank)
        if dynamic_method == "sv_ratio":
            target_rank = index_sv_ratio(S, dynamic_param)
        target_rank = min(target_rank, S.shape[0])
        # Keep the top singular directions; fold S into the up matrix.
        U = U[:, :target_rank] @ torch.diag(S[:target_rank])
        Vh = Vh[:target_rank, :]
        if len(down.shape) == 4:
            U = U.reshape(up.shape[0], target_rank, 1, 1)
            Vh = Vh.reshape(target_rank, down.shape[1], down.shape[2], down.shape[3])
        new_state[f"{stem}.lora_down.weight"] = Vh
        new_state[f"{stem}.lora_up.weight"] = U
        new_state[f"{stem}.alpha"] = torch.tensor(target_rank).float()
    out = TempDir / "resized.safetensors"
    save_file(new_state, out)
    api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="resized.safetensors", repo_id=out_repo, token=hf_token)
    return "Done"
| # ================================================================================= | |
| # UI | |
| # ================================================================================= | |
css = ".container { max-width: 900px; margin: auto; }"

# Fix: custom CSS must be passed to gr.Blocks(); Blocks.launch() has no
# `css` parameter, so the original `launch(css=css, ...)` raised a TypeError.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# 🧰SOONmerge® LoRA Toolkit")
    with gr.Tabs():
        with gr.Tab("Merge to Base + Reshard Output"):
            t1_token = gr.Textbox(label="Token", type="password")
            t1_base = gr.Textbox(label="Base Repo (Diffusers)", value="ostris/Z-Image-De-Turbo")
            t1_sub = gr.Textbox(label="Subfolder", value="transformer")
            t1_lora = gr.Textbox(label="LoRA Direct Link", value="https://huggingface.co/GuangyuanSD/Z-Image-Re-Turbo-LoRA/resolve/main/Z-image_re_turbo_lora_8steps_rank_32_v1_fp16.safetensors")
            with gr.Row():
                t1_scale = gr.Slider(label="Scale", value=1.0, minimum=0, maximum=3.0, step=0.1)
                t1_prec = gr.Radio(["bf16", "fp16", "float32"], value="bf16", label="Precision")
                t1_shard = gr.Slider(label="Shard Size (GB)", value=2.0, minimum=0.1, maximum=10.0, step=0.1)
            t1_out = gr.Textbox(label="Output Repo")
            t1_struct = gr.Textbox(label="Diffusers Extras (Copies VAE/TextEnc/etc)", value="Tongyi-MAI/Z-Image-Turbo")
            t1_priv = gr.Checkbox(label="Private", value=True)
            t1_btn = gr.Button("Merge")
            t1_res = gr.Textbox(label="Result")
            t1_btn.click(task_merge, [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_prec, t1_shard, t1_out, t1_struct, t1_priv], t1_res)
        with gr.Tab("Extract Adapter"):
            t2_token = gr.Textbox(label="Token", type="password")
            t2_org = gr.Textbox(label="Original Model")
            t2_tun = gr.Textbox(label="Tuned Model")
            t2_rank = gr.Number(label="Extract At Rank", value=32, minimum=1, maximum=1024, step=1)
            t2_out = gr.Textbox(label="Output Repo")
            t2_btn = gr.Button("Extract")
            t2_res = gr.Textbox(label="Result")
            t2_btn.click(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], t2_res)
        with gr.Tab("Merge Multiple Adapters"):
            t3_token = gr.Textbox(label="Token", type="password")
            t3_urls = gr.Textbox(label="URLs")
            with gr.Row():
                t3_beta = gr.Slider(label="Beta", value=0.95, minimum=0.01, maximum=1.00, step=0.01)
                t3_sigma = gr.Slider(label="Sigma Rel (Overrides Beta)", value=0.21, minimum=0.01, maximum=1.00, step=0.01)
            t3_out = gr.Textbox(label="Output Repo")
            t3_btn = gr.Button("Merge")
            t3_res = gr.Textbox(label="Result")
            t3_btn.click(task_merge_adapters, [t3_token, t3_urls, t3_beta, t3_sigma, t3_out], t3_res)
        with gr.Tab("Resize Adapter"):
            t4_token = gr.Textbox(label="Token", type="password")
            t4_in = gr.Textbox(label="LoRA")
            with gr.Row():
                t4_rank = gr.Number(label="To Rank (Lower Only!)", value=8, minimum=1, maximum=256, step=1)
                t4_method = gr.Dropdown(["None", "sv_ratio"], value="None", label="Dynamic Method")
                t4_param = gr.Number(label="Dynamic Param", value=4.0)
            t4_out = gr.Textbox(label="Output")
            t4_btn = gr.Button("Resize")
            t4_res = gr.Textbox(label="Result")
            t4_btn.click(task_resize, [t4_token, t4_in, t4_rank, t4_method, t4_param, t4_out], t4_res)

if __name__ == "__main__":
    demo.queue().launch(ssr_mode=False)