import gradio as gr
import torch
import os
import gc
import shutil
import requests
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from tqdm import tqdm

# --- Constants & Setup ---
TempDir = Path("./temp_merge")
os.makedirs(TempDir, exist_ok=True)
api = HfApi()

def info_log(msg):
    """Print a message and return it unchanged (handy for building the UI log)."""
    print(msg)
    return msg

def cleanup_temp():
    if TempDir.exists():
        shutil.rmtree(TempDir)
    os.makedirs(TempDir, exist_ok=True)
    gc.collect()

# --- Core Logic ---

def download_lora(lora_input, hf_token):
    """Downloads LoRA from a Repo ID or a direct URL."""
    local_path = TempDir / "adapter.safetensors"
    
    if lora_input.startswith("http"):
        # Direct URL download
        print(f"Downloading LoRA from URL: {lora_input}")
        response = requests.get(lora_input, stream=True, timeout=60)
        response.raise_for_status()
        with open(local_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        return local_path
    else:
        # Repo ID download
        print(f"Downloading LoRA from Repo: {lora_input}")
        # Try finding the safetensors file
        try:
            return hf_hub_download(repo_id=lora_input, filename="adapter_model.safetensors", token=hf_token, local_dir=TempDir)
        except Exception:
            # Fallback for diffusion models which might use different names
            files = list_repo_files(repo_id=lora_input, token=hf_token)
            safe_files = [f for f in files if f.endswith(".safetensors") and "adapter" in f]
            if not safe_files:
                # Last ditch: grab the first safetensors
                safe_files = [f for f in files if f.endswith(".safetensors")]
            
            if not safe_files:
                raise ValueError("Could not find a .safetensors file in the LoRA repo.")
            
            return hf_hub_download(repo_id=lora_input, filename=safe_files[0], token=hf_token, local_dir=TempDir)
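
# Accepted LoRA sources (illustrative values only):
#   "username/my-lora"                                       -> resolved via the Hub API
#   "https://huggingface.co/.../resolve/main/x.safetensors"  -> direct URL download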

def load_lora_weights(path):
    """Loads the LoRA state dict onto CPU.

    Rank/alpha are not parsed here; for raw merging we only need the tensors,
    and any alpha scaling is assumed to be folded into the user-supplied scale.
    """
    return load_file(path, device="cpu")
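
# Hedged helper (not called by the merge loop): some trainers embed rank/alpha
# in the safetensors header metadata. This sketch reads that header without
# loading any tensors; the key names in the returned dict vary by trainer.
def read_lora_metadata(path):
    with safe_open(path, framework="pt", device="cpu") as f:
        return f.metadata() or {}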

def match_keys(base_key, lora_keys):
    """
    Heuristic matching against common LoRA naming conventions
    (PEFT: lora_A/lora_B; kohya-style: lora_down/lora_up).
    """
    # A base key like "transformer.blocks.0.to_q.weight" typically pairs with
    # "transformer.blocks.0.to_q.lora_A.weight": the marker is inserted BEFORE
    # ".weight", so the raw base key is never a substring of the LoRA key.
    # Strip the suffix and match on the stem instead.
    stem = base_key[: -len(".weight")] if base_key.endswith(".weight") else base_key

    candidates = [k for k in lora_keys if stem in k]

    pair_A = None
    pair_B = None

    for k in candidates:
        if "lora_A" in k or "lora_down" in k:
            pair_A = k
        elif "lora_B" in k or "lora_up" in k:
            pair_B = k

    return pair_A, pair_B
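
# Illustrative usage (hypothetical key names):
#   match_keys("blocks.0.to_q.weight",
#              ["blocks.0.to_q.lora_A.weight", "blocks.0.to_q.lora_B.weight"])
#   -> ("blocks.0.to_q.lora_A.weight", "blocks.0.to_q.lora_B.weight")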

def copy_auxiliary_files(src_repo, tgt_repo, token):
    """Copies config/tokenizer/scheduler files from source to target."""
    print(f"Copying infrastructure from {src_repo} to {tgt_repo}...")
    files = list_repo_files(repo_id=src_repo, token=token)

    # Copy everything except heavy weight files (configs, tokenizer files,
    # model_index.json, etc. all pass through)
    weight_exts = (".safetensors", ".bin", ".pt", ".pth", ".msgpack", ".h5")
    files_to_copy = [f for f in files if not f.endswith(weight_exts)]

    for f in tqdm(files_to_copy, desc="Copying configs"):
        try:
            # Download into the local HF cache, upload, then delete the local copy
            local = hf_hub_download(repo_id=src_repo, filename=f, token=token)
            api.upload_file(
                path_or_fileobj=local,
                path_in_repo=f,
                repo_id=tgt_repo,
                repo_type="model",
                token=token
            )
            os.remove(local)
        except Exception as e:
            print(f"Skipped {f}: {e}")

def run_merge(
    hf_token, 
    base_repo, 
    base_subfolder,
    structure_repo,
    lora_input, 
    scale, 
    output_repo, 
    is_private,
    progress=gr.Progress()
):
    cleanup_temp()
    logs = []
    
    try:
        login(hf_token)
        logs.append(f"Logged in. Target: {output_repo}")
        
        # 1. Create Output Repo
        try:
            api.create_repo(repo_id=output_repo, private=is_private, exist_ok=True, token=hf_token)
            logs.append("Output repository ready.")
        except Exception as e:
            return "\n".join(logs) + f"\nError creating repo: {e}"

        # 2. Replicate Structure (If requested)
        if structure_repo.strip():
            progress(0.1, desc="Cloning Model Structure (Configs)...")
            logs.append(f"Cloning configuration from {structure_repo}...")
            copy_auxiliary_files(structure_repo, output_repo, hf_token)
            logs.append("Configuration files copied.")

        # 3. Load LoRA
        progress(0.2, desc="Downloading LoRA...")
        logs.append(f"Fetching LoRA: {lora_input}")
        lora_path = download_lora(lora_input, hf_token)
        lora_state = load_lora_weights(lora_path)
        lora_keys = list(lora_state.keys())
        logs.append(f"LoRA loaded. Found {len(lora_keys)} tensors.")

        # 4. Identify Base Shards
        progress(0.3, desc="Analyzing Base Model...")
        all_files = list_repo_files(repo_id=base_repo, token=hf_token)
        
        # Filter for safetensors in the specific subfolder (if provided)
        target_shards = []
        for f in all_files:
            if not f.endswith(".safetensors"):
                continue
            
            # Check subfolder constraint: require a full path-segment match so
            # "transformer" does not also catch e.g. "transformer_2/..."
            if base_subfolder.strip():
                prefix = base_subfolder.strip().strip("/") + "/"
                if not f.startswith(prefix):
                    continue
            
            target_shards.append(f)
            
        logs.append(f"Found {len(target_shards)} matching safetensors shards in base.")
        if not target_shards:
            raise ValueError("No safetensors found in the specified base repo/subfolder.")

        # 5. Process Shards (Streamed)
        total_shards = len(target_shards)
        merged_count = 0
        
        for idx, shard_file in enumerate(target_shards):
            progress(0.3 + (0.6 * (idx / total_shards)), desc=f"Processing Shard {idx+1}/{total_shards}")
            logs.append(f"--- Processing {shard_file} ---")
            
            # Download Shard
            local_shard = hf_hub_download(repo_id=base_repo, filename=shard_file, token=hf_token, local_dir=TempDir)
            
            # Load and Merge.
            # load_file brings the whole shard into CPU RAM; shards are processed
            # one at a time so peak memory stays near a single shard plus its copy.
            base_tensors = load_file(local_shard, device="cpu")
            modified_tensors = {}
            has_changes = False
            
            for key, tensor in base_tensors.items():
                # Match LoRA keys to this base key. Handles architectural prefix
                # mismatches (e.g. the base repo may rely on folder structure while
                # the LoRA keys carry a "transformer." prefix, or vice versa).
                pair_A, pair_B = match_keys(key, lora_keys)

                # Fallback for the opposite mismatch: the LoRA keys may be SHORTER
                # than the base keys (base "transformer.blocks.1..." vs LoRA
                # "blocks.1..."), so compare stems with LoRA markers stripped.
                if not pair_A:
                    stem = key[: -len(".weight")] if key.endswith(".weight") else key
                    for k in lora_keys:
                        k_stem = k
                        for marker in (".lora_A", ".lora_B", ".lora_down", ".lora_up"):
                            k_stem = k_stem.replace(marker, "")
                        if k_stem.endswith(".weight"):
                            k_stem = k_stem[: -len(".weight")]
                        if not k_stem or not stem.endswith(k_stem):
                            continue
                        if "lora_A" in k or "lora_down" in k:
                            pair_A = k
                        elif "lora_B" in k or "lora_up" in k:
                            pair_B = k

                if pair_A and pair_B:
                    # Apply Merge: W' = W + scale * (B @ A).
                    # Any trainer-side alpha/rank scaling is assumed to be folded
                    # into the user-supplied scale.
                    w_a = lora_state[pair_A].float()
                    w_b = lora_state[pair_B].float()

                    # Target tensor (accumulate in fp32, cast back afterwards)
                    current_tensor = tensor.float()

                    # The delta's shape should match current_tensor; some training
                    # libs store the factors transposed relative to the base layout,
                    # which the check below handles.
                    delta = (w_b @ w_a) * scale
                    
                    if delta.shape != current_tensor.shape:
                        # Try transposing matches
                        if delta.T.shape == current_tensor.shape:
                            delta = delta.T
                        else:
                            logs.append(f"Warning: Shape mismatch for {key}. Base: {current_tensor.shape}, LoRA Delta: {delta.shape}. Skipping.")
                            modified_tensors[key] = tensor
                            continue
                            
                    modified_tensors[key] = (current_tensor + delta).to(tensor.dtype)
                    merged_count += 1
                    has_changes = True
                else:
                    modified_tensors[key] = tensor

            # Save and Upload
            if has_changes:
                logs.append("Merging complete for shard. Saving...")
                output_path = TempDir / "processed.safetensors"
                save_file(modified_tensors, output_path)
                
                api.upload_file(
                    path_or_fileobj=output_path,
                    path_in_repo=shard_file, # Keep original structure
                    repo_id=output_repo,
                    repo_type="model",
                    token=hf_token
                )
                logs.append(f"Uploaded {shard_file}")
            else:
                # No LoRA matches in this shard: upload the original file untouched,
                # which avoids re-serialising the tensor dict.
                logs.append("No LoRA matches in this shard. Copying original...")
                api.upload_file(
                    path_or_fileobj=local_shard,
                    path_in_repo=shard_file,
                    repo_id=output_repo,
                    repo_type="model",
                    token=hf_token
                )
            
            # Cleanup Memory immediately
            del base_tensors
            del modified_tensors
            if 'delta' in locals(): del delta
            gc.collect()
            os.remove(local_shard)
            if os.path.exists(TempDir / "processed.safetensors"):
                os.remove(TempDir / "processed.safetensors")

        progress(1.0, desc="Done!")
        logs.append(f"\nSUCCESS. Merged {merged_count} layers total.")
        logs.append(f"New model available at: https://huggingface.co/{output_repo}")
        
    except Exception as e:
        import traceback
        logs.append(f"\nCRITICAL ERROR: {str(e)}")
        logs.append(traceback.format_exc())
    
    finally:
        cleanup_temp()
        
    return "\n".join(logs)


# --- UI ---

css = """
.container { max-width: 900px; margin: auto; }
.header { text-align: center; margin-bottom: 20px; }
"""

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown(
        """
        # ⚡ Universal LoRA Merger & Reconstructor
        
        Merge LoRA adapters into **any** base model (LLM, Diffusion, Audio) and reconstruct the repository structure.
        Optimized for CPU-only execution on Hugging Face Spaces.
        """
    )
    
    with gr.Group():
        gr.Markdown("### 1. Authentication & Output")
        with gr.Row():
            hf_token = gr.Textbox(label="HF Write Token", type="password", placeholder="hf_...")
            output_repo = gr.Textbox(label="Target Output Repo", placeholder="username/Z-Image-Turbo-Custom")
            is_private = gr.Checkbox(label="Private Repo", value=True)

    with gr.Group():
        gr.Markdown("### 2. Base Weights (The Target)")
        with gr.Row():
            base_repo = gr.Textbox(label="Base Model Repo", placeholder="e.g. ostris/Z-Image-De-Turbo")
            base_subfolder = gr.Textbox(label="Subfolder (Optional)", placeholder="e.g. transformer", info="Only merge weights found inside this folder.")

    with gr.Group():
        gr.Markdown("### 3. LoRA Configuration")
        with gr.Row():
            lora_input = gr.Textbox(label="LoRA Source", placeholder="Repo ID OR Direct URL (http...)", info="Accepts direct .safetensors resolve links.")
            scale = gr.Slider(label="Scale", minimum=-2.0, maximum=2.0, value=1.0, step=0.1)

    with gr.Group():
        gr.Markdown("### 4. Repository Reconstruction (Optional)")
        gr.Markdown("*Use this to fill in missing files (Scheduler, VAE, Tokenizer, model_index.json) from a different source repo.*")
        structure_repo = gr.Textbox(label="Structure Source Repo", placeholder="e.g. Tongyi-MAI/Z-Image-Turbo", info="Copies all NON-weight files from here to output.")

    submit_btn = gr.Button("🚀 Start Merge & Upload", variant="primary")
    
    output_log = gr.Textbox(label="Process Log", lines=20, interactive=False)

    submit_btn.click(
        fn=run_merge,
        inputs=[hf_token, base_repo, base_subfolder, structure_repo, lora_input, scale, output_repo, is_private],
        outputs=output_log
    )

if __name__ == "__main__":
    demo.queue(max_size=1).launch()