# Source: Hugging Face Space (status: Running); file size 13,845 bytes; revision 58aaafc.
import gradio as gr
import torch
import os
import gc
import json
import shutil
import requests
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from tqdm import tqdm
# --- Constants & Setup ---
# Scratch space for downloads and rewritten shards; wiped and recreated by cleanup_temp().
TempDir = Path("temp_merge")
TempDir.mkdir(parents=True, exist_ok=True)

# Shared Hub client; every call site passes its token explicitly.
api = HfApi()
def info_log(msg, progress=None):
    """Print ``msg`` to stdout and return it unchanged.

    Args:
        msg: Message to log.
        progress: Accepted for call-site compatibility; currently ignored.

    Returns:
        ``msg`` itself, so callers can log and collect in one expression.
    """
    print(msg)
    return msg
def cleanup_temp():
    """Reset the scratch directory ``TempDir`` to an empty state and run the GC.

    Deletion is best-effort (``ignore_errors=True``). ``run_merge`` calls this
    both before starting and from its ``finally`` clause. A file that cannot
    be removed must not raise there, because that exception would escape the
    ``finally`` and discard the merge log that is returned afterwards.
    """
    if TempDir.exists():
        shutil.rmtree(TempDir, ignore_errors=True)
    os.makedirs(TempDir, exist_ok=True)
    gc.collect()
# --- Core Logic ---
def download_lora(lora_input, hf_token):
    """Download a LoRA ``.safetensors`` file and return its local path.

    Args:
        lora_input: Either a direct ``http://`` / ``https://`` link to a
            ``.safetensors`` file, or a Hugging Face repo ID.
        hf_token: Hub token. Always used for repo downloads. For URL
            downloads it is attached as a Bearer header only when the URL's
            host is ``huggingface.co`` or ``hf.co``, so private resolve links
            work without leaking the token to other hosts.

    Returns:
        Path to the downloaded file: a ``Path`` for URL downloads, or the
        ``str`` returned by ``hf_hub_download`` for repo downloads.

    Raises:
        requests.HTTPError: the URL responded with an error status.
        ValueError: the repo contains no ``.safetensors`` file.
    """
    from urllib.parse import urlparse

    lora_input = lora_input.strip()

    if lora_input.startswith(("http://", "https://")):
        # Direct URL download, streamed to disk in chunks.
        print(f"Downloading LoRA from URL: {lora_input}")
        local_path = TempDir / "adapter.safetensors"
        headers = {}
        if hf_token and urlparse(lora_input).hostname in ("huggingface.co", "hf.co"):
            headers["Authorization"] = f"Bearer {hf_token}"
        # (connect, read) timeouts so a stalled server cannot block the worker forever;
        # the `with` closes the streamed connection even if writing fails.
        with requests.get(lora_input, stream=True, timeout=(10, 300), headers=headers) as response:
            response.raise_for_status()
            with open(local_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        return local_path

    # Repo ID download
    print(f"Downloading LoRA from Repo: {lora_input}")
    try:
        # Try the conventional adapter file name first.
        return hf_hub_download(repo_id=lora_input, filename="adapter_model.safetensors", token=hf_token, local_dir=TempDir)
    except Exception as e:
        # Not fatal: diffusion LoRAs often use other file names, so scan the repo.
        print(f"adapter_model.safetensors not available ({e}); searching repo files...")

    files = list_repo_files(repo_id=lora_input, token=hf_token)
    safetensors_files = [f for f in files if f.endswith(".safetensors")]
    # Prefer files with "adapter" in the name; last resort is the first .safetensors.
    candidates = [f for f in safetensors_files if "adapter" in f] or safetensors_files
    if not candidates:
        raise ValueError("Could not find a .safetensors file in the LoRA repo.")
    return hf_hub_download(repo_id=lora_input, filename=candidates[0], token=hf_token, local_dir=TempDir)
def load_lora_weights(path):
    """Load every tensor in a LoRA ``.safetensors`` file onto the CPU.

    Args:
        path: Local path to the ``.safetensors`` file.

    Returns:
        dict mapping tensor name -> tensor, exactly as stored in the file.
        No rank/alpha detection happens here; any such entries in the file
        are returned like every other tensor.
    """
    return load_file(path, device="cpu")
def match_keys(base_key, lora_keys):
    """Find the LoRA factor pair that targets ``base_key``.

    A base tensor ``{module}.weight`` matches LoRA keys shaped like
    ``[{prefix}.]{module}.{tag}[.{rest}]``, where ``{tag}`` is ``lora_A`` or
    ``lora_down`` (down-projection) and ``lora_B`` or ``lora_up``
    (up-projection). The optional ``{prefix}`` absorbs wrapper names such as
    ``transformer.`` or ``base_model.model.``.

    The module must match on whole dot-separated components. So
    ``attn.to_q.weight`` does not pick up ``attn.to_qkv.lora_A.weight``, and
    ``blocks.0...`` does not pick up ``xblocks.0...``. Keys that do not end in
    ``.weight`` (e.g. biases) are used verbatim as the module path, so
    ``x.bias`` never pairs with the factors of ``x``.

    Args:
        base_key: Tensor name from the base checkpoint.
        lora_keys: Iterable of tensor names from the LoRA state dict.

    Returns:
        ``(pair_A, pair_B)``: the down- and up-projection keys, each ``None``
        when not found. If several keys qualify, the last one wins.
    """
    down_tags = ("lora_A", "lora_down")
    up_tags = ("lora_B", "lora_up")
    module = base_key[: -len(".weight")] if base_key.endswith(".weight") else base_key
    dotted_module = "." + module

    pair_A = None
    pair_B = None
    for k in lora_keys:
        if module not in k:  # cheap reject before splitting
            continue
        parts = k.split(".")
        tag_index = next(
            (i for i, part in enumerate(parts) if part in down_tags or part in up_tags),
            None,
        )
        if tag_index is None:
            continue  # not a factor tensor (e.g. an ".alpha" entry)
        lora_module = ".".join(parts[:tag_index])
        if lora_module != module and not lora_module.endswith(dotted_module):
            continue
        if parts[tag_index] in down_tags:
            pair_A = k
        else:
            pair_B = k
    return pair_A, pair_B
# File extensions treated as model weights; copy_auxiliary_files never copies these.
_WEIGHT_SUFFIXES = (
    ".safetensors",
    ".bin",
    ".pt",
    ".pth",
    ".msgpack",
    ".h5",
    ".ckpt",
    ".gguf",
    ".onnx",
    ".onnx_data",
)


def copy_auxiliary_files(src_repo, tgt_repo, token, subfolder=""):
    """Copy non-weight files (configs, tokenizer, scheduler, ...) between repos.

    Each file is downloaded into ``TempDir / "aux"`` and uploaded to the same
    path in ``tgt_repo``. Failures are handled per file: the error is printed
    and that file is skipped, so one bad file does not abort the copy.

    Args:
        src_repo: Repo ID to copy from.
        tgt_repo: Model repo ID to copy into.
        token: Hub token (read access to ``src_repo``, write access to ``tgt_repo``).
        subfolder: If non-empty, only files under this folder are copied.
            The default (``""``) copies from the whole repo.
    """
    print(f"Copying infrastructure from {src_repo} to {tgt_repo}...")
    files = list_repo_files(repo_id=src_repo, token=token)

    folder = (subfolder or "").strip().strip("/")
    prefix = f"{folder}/" if folder else ""
    files_to_copy = [
        f for f in files
        if f.startswith(prefix) and not f.endswith(_WEIGHT_SUFFIXES)
    ]

    staging_dir = TempDir / "aux"
    for f in tqdm(files_to_copy, desc="Copying configs"):
        try:
            # Stage in the scratch dir (wiped by cleanup_temp) rather than the shared
            # HF cache, where the returned path is typically a link into the blob
            # store and os.remove would leave the blob behind.
            local = hf_hub_download(repo_id=src_repo, filename=f, token=token, local_dir=staging_dir)
            api.upload_file(
                path_or_fileobj=local,
                path_in_repo=f,
                repo_id=tgt_repo,
                repo_type="model",
                token=token,
            )
            os.remove(local)
        except Exception as e:
            print(f"Skipped {f}: {e}")
def _select_shards(all_files, base_subfolder):
    """Return the ``.safetensors`` paths in ``all_files`` that live under ``base_subfolder``.

    A blank subfolder selects every ``.safetensors`` file. The subfolder is
    matched as a whole path component, so ``"transformer"`` does not also
    select ``"transformer_2/..."``.
    """
    subfolder = (base_subfolder or "").strip().strip("/")
    prefix = f"{subfolder}/" if subfolder else ""
    return [f for f in all_files if f.endswith(".safetensors") and f.startswith(prefix)]


def _lora_delta(lora_state, pair_A, pair_B, scale):
    """Build the dense update ``scale * (alpha / rank) * (B @ A)`` as a 2-D float32 matrix.

    Factors with more than two dims (conv LoRAs) are first flattened to
    ``(rank, -1)`` and ``(out, rank)``. The ``alpha / rank`` factor is applied
    only when the state dict holds an ``.alpha`` entry beside the factors
    (e.g. ``x.lora_down.weight`` paired with ``x.alpha``). Otherwise ``scale``
    is used as-is.

    Raises:
        ValueError: the inner (rank) dimensions of the two factors disagree.
    """
    w_a = lora_state[pair_A].float()
    w_b = lora_state[pair_B].float()
    rank = w_a.shape[0]
    w_a = w_a.reshape(rank, -1)
    w_b = w_b.reshape(w_b.shape[0], -1)
    if w_b.shape[1] != rank:
        raise ValueError(
            f"Incompatible LoRA factors ({pair_B}: {tuple(lora_state[pair_B].shape)}, "
            f"{pair_A}: {tuple(lora_state[pair_A].shape)})."
        )
    factor = float(scale)
    alpha_key = pair_A.rsplit(".lora_", 1)[0] + ".alpha"
    if alpha_key in lora_state:
        factor *= float(lora_state[alpha_key]) / rank
    return (w_b @ w_a) * factor


def _apply_delta(tensor, delta):
    """Return ``tensor + delta`` cast to ``tensor``'s dtype, or ``None`` if shapes cannot be reconciled.

    ``delta`` is a 2-D matrix. It is used as-is when shapes match. It is
    transposed when the base is 2-D with the swapped shape, since some
    training libraries store the factors transposed. It is reshaped when the
    base has more than 2 dims (conv kernels) and the element counts agree.
    """
    if delta.shape != tensor.shape:
        if tensor.dim() == 2 and delta.T.shape == tensor.shape:
            delta = delta.T
        elif tensor.dim() > 2 and delta.numel() == tensor.numel():
            delta = delta.reshape(tensor.shape)
        else:
            return None
    return (tensor.float() + delta).to(tensor.dtype)


def _merge_shard_tensors(base_tensors, lora_state, lora_keys, scale, logs):
    """Merge matching LoRA deltas into ``base_tensors`` in place.

    Warnings for skipped tensors are appended to ``logs``.

    Returns:
        The number of tensors that were modified.
    """
    changed = 0
    # Iterate a snapshot of the keys so values can be replaced while looping.
    for key in list(base_tensors):
        tensor = base_tensors[key]
        pair_A, pair_B = match_keys(key, lora_keys)
        if not (pair_A and pair_B):
            continue
        try:
            delta = _lora_delta(lora_state, pair_A, pair_B, scale)
        except ValueError as e:
            logs.append(f"Warning: {e} Skipping {key}.")
            continue
        merged = _apply_delta(tensor, delta)
        if merged is None:
            logs.append(
                f"Warning: Shape mismatch for {key}. Base: {tensor.shape}, "
                f"LoRA Delta: {delta.shape}. Skipping."
            )
            continue
        base_tensors[key] = merged
        changed += 1
    return changed


def run_merge(
    hf_token,
    base_repo,
    base_subfolder,
    structure_repo,
    lora_input,
    scale,
    output_repo,
    is_private,
    progress=gr.Progress()
):
    """Merge a LoRA into every matching base shard and upload the result.

    Steps:
    1. Validate the token and create ``output_repo``.
    2. Optionally copy non-weight files from ``structure_repo``.
    3. Download the LoRA.
    4. Process the base ``.safetensors`` shards one at a time (download, merge,
       upload, delete), so at most one shard is held locally.

    Args:
        hf_token: Hugging Face token with write access to ``output_repo``.
        base_repo: Repo ID holding the base weights.
        base_subfolder: Optional folder inside ``base_repo`` to restrict shards to.
        structure_repo: Optional repo ID whose non-weight files are copied over.
        lora_input: LoRA repo ID or direct ``http(s)`` URL.
        scale: Multiplier applied to every LoRA delta.
        output_repo: Destination repo ID (created if missing).
        is_private: Whether a newly created destination repo is private.
        progress: Gradio progress tracker.

    Returns:
        The accumulated log text. Exceptions are caught and reported in it.
    """
    cleanup_temp()
    logs = []
    # Raw textbox contents: stray whitespace would break repo IDs and tokens.
    hf_token = (hf_token or "").strip()
    base_repo = (base_repo or "").strip()
    structure_repo = (structure_repo or "").strip()
    lora_input = (lora_input or "").strip()
    output_repo = (output_repo or "").strip()
    try:
        required = (
            ("HF Write Token", hf_token),
            ("Base Model Repo", base_repo),
            ("LoRA Source", lora_input),
            ("Target Output Repo", output_repo),
        )
        missing = [label for label, value in required if not value]
        if missing:
            logs.append("Error: missing required field(s): " + ", ".join(missing))
            return "\n".join(logs)

        # Validate the token without login(): login() would persist it as the
        # machine-wide default token on a shared Space. Every Hub call below
        # passes the token explicitly.
        user = api.whoami(token=hf_token)
        logs.append(f"Logged in as {user.get('name', '?')}. Target: {output_repo}")

        # 1. Create Output Repo
        try:
            api.create_repo(repo_id=output_repo, private=is_private, exist_ok=True, token=hf_token)
            logs.append("Output repository ready.")
        except Exception as e:
            return "\n".join(logs) + f"\nError creating repo: {e}"

        # 2. Replicate Structure (If requested)
        if structure_repo:
            progress(0.1, desc="Cloning Model Structure (Configs)...")
            logs.append(f"Cloning configuration from {structure_repo}...")
            copy_auxiliary_files(structure_repo, output_repo, hf_token)
            logs.append("Configuration files copied.")

        # 3. Load LoRA
        progress(0.2, desc="Downloading LoRA...")
        logs.append(f"Fetching LoRA: {lora_input}")
        lora_path = download_lora(lora_input, hf_token)
        lora_state = load_lora_weights(lora_path)
        lora_keys = list(lora_state.keys())
        logs.append(f"LoRA loaded. Found {len(lora_keys)} tensors.")

        # 4. Identify Base Shards
        progress(0.3, desc="Analyzing Base Model...")
        all_files = list_repo_files(repo_id=base_repo, token=hf_token)
        target_shards = _select_shards(all_files, base_subfolder)
        logs.append(f"Found {len(target_shards)} matching safetensors shards in base.")
        if not target_shards:
            raise ValueError("No safetensors found in the specified base repo/subfolder.")

        # 5. Process Shards (Streamed)
        total_shards = len(target_shards)
        merged_count = 0
        output_path = TempDir / "processed.safetensors"
        for idx, shard_file in enumerate(target_shards):
            progress(0.3 + (0.6 * (idx / total_shards)), desc=f"Processing Shard {idx+1}/{total_shards}")
            logs.append(f"--- Processing {shard_file} ---")
            local_shard = hf_hub_download(repo_id=base_repo, filename=shard_file, token=hf_token, local_dir=TempDir)

            # Keep the shard's header metadata so the rewritten file carries it too.
            with safe_open(local_shard, framework="pt") as reader:
                shard_metadata = reader.metadata()
            base_tensors = load_file(local_shard, device="cpu")

            shard_merged = _merge_shard_tensors(base_tensors, lora_state, lora_keys, scale, logs)
            merged_count += shard_merged

            if shard_merged:
                logs.append(f"Merged {shard_merged} layers in this shard. Saving...")
                save_file(base_tensors, output_path, metadata=shard_metadata)
                upload_source = output_path
            else:
                # Nothing changed: upload the downloaded file instead of re-serializing.
                logs.append("No LoRA matches in this shard. Copying original...")
                upload_source = local_shard
            api.upload_file(
                path_or_fileobj=upload_source,
                path_in_repo=shard_file,  # keep the original repo layout
                repo_id=output_repo,
                repo_type="model",
                token=hf_token,
            )
            logs.append(f"Uploaded {shard_file}")

            # Release this shard's memory and disk before fetching the next one.
            del base_tensors
            gc.collect()
            os.remove(local_shard)
            if output_path.exists():
                output_path.unlink()

        progress(1.0, desc="Done!")
        if merged_count == 0:
            logs.append(
                "\nWarning: no LoRA weights matched any base tensor; "
                "the output is an unmodified copy of the base shards."
            )
        logs.append(f"\nSUCCESS. Merged {merged_count} layers total.")
        logs.append(f"New model available at: https://huggingface.co/{output_repo}")
    except Exception as e:
        import traceback
        logs.append(f"\nCRITICAL ERROR: {str(e)}")
        logs.append(traceback.format_exc())
    finally:
        cleanup_temp()
    return "\n".join(logs)
# --- UI ---
# Page styles passed to gr.Blocks. NOTE(review): no component below sets
# elem_classes, so ".container" and ".header" appear to be unused.
css = """
.container { max-width: 900px; margin: auto; }
.header { text-align: center; margin-bottom: 20px; }
"""
# Layout is defined by statement order: each component lands in whichever
# Group/Row context is open when it is created.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown(
"""
# ⚡ Universal LoRA Merger & Reconstructor
Merge LoRA adapters into **any** base model (LLM, Diffusion, Audio) and reconstruct the repository structure.
Optimized for CPU-only execution on Hugging Face Spaces.
"""
    )
    # Section 1: credentials and destination repo settings.
    with gr.Group():
        gr.Markdown("### 1. Authentication & Output")
        with gr.Row():
            hf_token = gr.Textbox(label="HF Write Token", type="password", placeholder="hf_...")
            output_repo = gr.Textbox(label="Target Output Repo", placeholder="username/Z-Image-Turbo-Custom")
            is_private = gr.Checkbox(label="Private Repo", value=True)
    # Section 2: where the base .safetensors shards come from.
    with gr.Group():
        gr.Markdown("### 2. Base Weights (The Target)")
        with gr.Row():
            base_repo = gr.Textbox(label="Base Model Repo", placeholder="e.g. ostris/Z-Image-De-Turbo")
            base_subfolder = gr.Textbox(label="Subfolder (Optional)", placeholder="e.g. transformer", info="Only merge weights found inside this folder.")
    # Section 3: LoRA source and the multiplier applied to its deltas.
    with gr.Group():
        gr.Markdown("### 3. LoRA Configuration")
        with gr.Row():
            lora_input = gr.Textbox(label="LoRA Source", placeholder="Repo ID OR Direct URL (http...)", info="Accepts direct .safetensors resolve links.")
            scale = gr.Slider(label="Scale", minimum=-2.0, maximum=2.0, value=1.0, step=0.1)
    # Section 4: optional repo whose non-weight files are copied to the output.
    with gr.Group():
        gr.Markdown("### 4. Repository Reconstruction (Optional)")
        gr.Markdown("*Use this to fill in missing files (Scheduler, VAE, Tokenizer, model_index.json) from a different source repo.*")
        structure_repo = gr.Textbox(label="Structure Source Repo", placeholder="e.g. Tongyi-MAI/Z-Image-Turbo", info="Copies all NON-weight files from here to output.")
    submit_btn = gr.Button("🚀 Start Merge & Upload", variant="primary")
    # run_merge returns the full log text, which is shown here.
    output_log = gr.Textbox(label="Process Log", lines=20, interactive=False)
    # Inputs are passed positionally and must stay in run_merge's parameter
    # order (hf_token, base_repo, base_subfolder, structure_repo, lora_input,
    # scale, output_repo, is_private); `progress` is not listed here and keeps
    # its gr.Progress() default.
    submit_btn.click(
        fn=run_merge,
        inputs=[hf_token, base_repo, base_subfolder, structure_repo, lora_input, scale, output_repo, is_private],
        outputs=output_log
    )
if __name__ == "__main__":
    # Enable Gradio's request queue; max_size=1 caps how many events may wait
    # in it at once (further submissions are turned away while it is full).
    demo.queue(max_size=1).launch()