AlekseyCalvin committed on
Commit
58aaafc
·
verified ·
1 Parent(s): 51d846a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +360 -0
app.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ import gc
5
+ import json
6
+ import shutil
7
+ import requests
8
+ from pathlib import Path
9
+ from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
10
+ from safetensors.torch import load_file, save_file
11
+ from safetensors import safe_open
12
+ from tqdm import tqdm
13
+
14
# --- Constants & Setup ---

# Scratch directory for downloaded shards and intermediate merge output.
TempDir = Path("./temp_merge")
TempDir.mkdir(parents=True, exist_ok=True)

# Shared Hub client reused by every upload/listing helper below.
api = HfApi()
18
+
19
def info_log(msg, progress=None):
    """Print *msg* to stdout and return it unchanged.

    The ``progress`` parameter is kept for call-site compatibility but is
    unused: the original implementation branched on it yet returned ``msg``
    in both branches, so the conditional was dead code and has been removed.

    Args:
        msg: The message to log.
        progress: Ignored; accepted only for backward compatibility.

    Returns:
        The same ``msg`` string, so the call can be chained into log lists.
    """
    print(msg)
    return msg
24
+
25
def cleanup_temp():
    """Reset the scratch area: delete any leftovers, recreate the directory
    empty, and trigger a garbage-collection pass to release large tensors
    held over from a previous run."""
    if TempDir.exists():
        shutil.rmtree(TempDir)
    TempDir.mkdir(parents=True, exist_ok=True)
    gc.collect()
30
+
31
+ # --- Core Logic ---
32
+
33
def download_lora(lora_input, hf_token):
    """Fetch a LoRA ``.safetensors`` file and return its local path.

    Args:
        lora_input: Either a direct HTTP(S) URL to a ``.safetensors`` file,
            or a Hugging Face repo id to search for an adapter file.
        hf_token: Hub token used for authenticated downloads.

    Returns:
        Path (or path string) of the downloaded adapter file.

    Raises:
        ValueError: If a repo id is given and it contains no ``.safetensors``
            file at all.
        requests.HTTPError: If a direct URL download fails.
    """
    local_path = TempDir / "adapter.safetensors"

    if lora_input.startswith("http"):
        # Direct URL download, streamed to disk in chunks so large adapters
        # never need to fit in memory at once.
        print(f"Downloading LoRA from URL: {lora_input}")
        response = requests.get(lora_input, stream=True)
        response.raise_for_status()
        with open(local_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        return local_path
    else:
        # Repo ID download: try the conventional PEFT filename first.
        print(f"Downloading LoRA from Repo: {lora_input}")
        try:
            return hf_hub_download(repo_id=lora_input, filename="adapter_model.safetensors", token=hf_token, local_dir=TempDir)
        except Exception:
            # FIX: was a bare `except:`, which also swallowed SystemExit and
            # KeyboardInterrupt; narrowed to Exception.
            # Fallback for diffusion models which might use different names.
            files = list_repo_files(repo_id=lora_input, token=hf_token)
            safe_files = [f for f in files if f.endswith(".safetensors") and "adapter" in f]
            if not safe_files:
                # Last ditch: grab the first safetensors in the repo.
                safe_files = [f for f in files if f.endswith(".safetensors")]

            if not safe_files:
                raise ValueError("Could not find a .safetensors file in the LoRA repo.")

            return hf_hub_download(repo_id=lora_input, filename=safe_files[0], token=hf_token, local_dir=TempDir)
64
+
65
def load_lora_weights(path):
    """Read a safetensors LoRA checkpoint into a CPU state dict."""
    # Only the raw tensor dict is needed for merging; rank/alpha metadata
    # could be parsed here later if the merge math ever requires it.
    state_dict = load_file(path, device="cpu")
    return state_dict
71
+
72
def match_keys(base_key, lora_keys):
    """Locate the LoRA A/B (down/up) tensor keys corresponding to *base_key*.

    Matching is a heuristic substring check: exact key matches are rare for
    LoRA, so a LoRA key such as ``transformer.blocks.0.lora_A.weight`` is
    matched because it contains the base key ``transformer.blocks.0``.
    Handles both naming conventions (``lora_A``/``lora_B`` and
    ``lora_down``/``lora_up``).

    Args:
        base_key: Key of a tensor in the base model's state dict.
        lora_keys: Iterable of keys from the LoRA state dict.

    Returns:
        A ``(pair_A, pair_B)`` tuple of matching LoRA keys; either element is
        ``None`` when no match exists.  When several keys match, the last one
        in ``lora_keys`` order wins (unchanged from the original behavior).
    """
    # FIX: removed the unused `matches = {}` dict and the intermediate
    # candidate list; a single pass with a containment guard is equivalent.
    pair_A = None
    pair_B = None

    for key in lora_keys:
        if base_key not in key:
            continue
        if "lora_A" in key or "lora_down" in key:
            pair_A = key
        elif "lora_B" in key or "lora_up" in key:
            pair_B = key

    return pair_A, pair_B
100
+
101
def copy_auxiliary_files(src_repo, tgt_repo, token, subfolder=""):
    """Replicate every non-weight file (configs, tokenizer, scheduler, ...)
    from ``src_repo`` into ``tgt_repo``.

    ``subfolder`` is accepted for API compatibility but is not used by the
    current implementation.  Failures on individual files are logged and
    skipped so one bad file does not abort the whole copy.
    """
    print(f"Copying infrastructure from {src_repo} to {tgt_repo}...")
    all_files = list_repo_files(repo_id=src_repo, token=token)

    # str.endswith accepts a tuple, so one call filters out every heavy
    # weight format at once.
    weight_exts = (".safetensors", ".bin", ".pt", ".pth", ".msgpack", ".h5")
    files_to_copy = [name for name in all_files if not name.endswith(weight_exts)]

    for name in tqdm(files_to_copy, desc="Copying configs"):
        try:
            # Download to the local cache, push straight to the target repo,
            # then delete the local copy to keep disk usage flat.
            local = hf_hub_download(repo_id=src_repo, filename=name, token=token)
            api.upload_file(
                path_or_fileobj=local,
                path_in_repo=name,
                repo_id=tgt_repo,
                repo_type="model",
                token=token,
            )
            os.remove(local)
        except Exception as e:
            print(f"Skipped {name}: {e}")
131
+
132
def run_merge(
    hf_token,
    base_repo,
    base_subfolder,
    structure_repo,
    lora_input,
    scale,
    output_repo,
    is_private,
    progress=gr.Progress()  # NOTE(review): mutable default is the standard Gradio progress-injection idiom
):
    """Merge a LoRA into a base model shard-by-shard and upload the result.

    Pipeline: create the output repo, optionally clone non-weight files from
    ``structure_repo``, download the LoRA, then stream each base safetensors
    shard (download -> merge matching keys -> upload -> delete) so only one
    shard is held in RAM/disk at a time.  Returns the accumulated log text;
    errors are captured into the log rather than raised to the UI.
    """
    cleanup_temp()
    logs = []

    try:
        login(hf_token)
        logs.append(f"Logged in. Target: {output_repo}")

        # 1. Create Output Repo
        try:
            api.create_repo(repo_id=output_repo, private=is_private, exist_ok=True, token=hf_token)
            logs.append("Output repository ready.")
        except Exception as e:
            # Repo creation failure is fatal; return the log early.
            return "\n".join(logs) + f"\nError creating repo: {e}"

        # 2. Replicate Structure (If requested)
        if structure_repo.strip():
            progress(0.1, desc="Cloning Model Structure (Configs)...")
            logs.append(f"Cloning configuration from {structure_repo}...")
            copy_auxiliary_files(structure_repo, output_repo, hf_token)
            logs.append("Configuration files copied.")

        # 3. Load LoRA
        progress(0.2, desc="Downloading LoRA...")
        logs.append(f"Fetching LoRA: {lora_input}")
        lora_path = download_lora(lora_input, hf_token)
        lora_state = load_lora_weights(lora_path)
        lora_keys = list(lora_state.keys())
        logs.append(f"LoRA loaded. Found {len(lora_keys)} tensors.")

        # 4. Identify Base Shards
        progress(0.3, desc="Analyzing Base Model...")
        all_files = list_repo_files(repo_id=base_repo, token=hf_token)

        # Filter for safetensors in the specific subfolder (if provided)
        target_shards = []
        for f in all_files:
            if not f.endswith(".safetensors"):
                continue

            # Check subfolder constraint
            if base_subfolder.strip():
                # Normalize paths (prefix match against the slash-stripped subfolder)
                if not f.startswith(base_subfolder.strip("/")):
                    continue

            target_shards.append(f)

        logs.append(f"Found {len(target_shards)} matching safetensors shards in base.")
        if not target_shards:
            raise ValueError("No safetensors found in the specified base repo/subfolder.")

        # 5. Process Shards (Streamed)
        total_shards = len(target_shards)
        merged_count = 0

        for idx, shard_file in enumerate(target_shards):
            progress(0.3 + (0.6 * (idx / total_shards)), desc=f"Processing Shard {idx+1}/{total_shards}")
            logs.append(f"--- Processing {shard_file} ---")

            # Download Shard
            local_shard = hf_hub_download(repo_id=base_repo, filename=shard_file, token=hf_token, local_dir=TempDir)

            # Load and Merge
            # We use safe_open to read metadata, but load_file for the dict to modify
            # load_file loads to CPU RAM.
            base_tensors = load_file(local_shard, device="cpu")
            modified_tensors = {}
            has_changes = False

            for key, tensor in base_tensors.items():
                # Match LoRA
                # Handle architectural prefix mismatches (e.g. Ostris repo might rely on folder structure,
                # while LoRA expects "transformer." prefix)

                # Try exact match first (unlikely for LoRA)
                pair_A, pair_B = match_keys(key, lora_keys)

                # If not found, try adding/removing common prefixes
                if not pair_A:
                    # Attempt to match "blocks.1..." to "transformer.blocks.1..."
                    matches = [k for k in lora_keys if key in k] # Simple substring check
                    for k in matches:
                        if "lora_A" in k or "lora_down" in k:
                            pair_A = k
                        elif "lora_B" in k or "lora_up" in k:
                            pair_B = k

                if pair_A and pair_B:
                    # Apply Merge: W' = W + scale * (B @ A), computed in fp32
                    w_a = lora_state[pair_A].float()
                    w_b = lora_state[pair_B].float()

                    # Target tensor
                    current_tensor = tensor.float()

                    # Dimension Check
                    # LoRA = B @ A. Shape should match current_tensor.
                    # Sometimes LoRA weights are transposed relative to base depending on training lib.
                    delta = (w_b @ w_a) * scale

                    if delta.shape != current_tensor.shape:
                        # Try transposing matches
                        if delta.T.shape == current_tensor.shape:
                            delta = delta.T
                        else:
                            # Unrecoverable mismatch: keep the original tensor untouched.
                            logs.append(f"Warning: Shape mismatch for {key}. Base: {current_tensor.shape}, LoRA Delta: {delta.shape}. Skipping.")
                            modified_tensors[key] = tensor
                            continue

                    # Cast back to the base tensor's original dtype after merging.
                    modified_tensors[key] = (current_tensor + delta).to(tensor.dtype)
                    merged_count += 1
                    has_changes = True
                else:
                    modified_tensors[key] = tensor

            # Save and Upload
            if has_changes:
                logs.append(f"Merging complete for shard. Saving...")
                output_path = TempDir / "processed.safetensors"
                save_file(modified_tensors, output_path)

                api.upload_file(
                    path_or_fileobj=output_path,
                    path_in_repo=shard_file, # Keep original structure
                    repo_id=output_repo,
                    repo_type="model",
                    token=hf_token
                )
                logs.append(f"Uploaded {shard_file}")
            else:
                # If no changes, just copy the original file to the new repo
                # This saves re-saving the tensor dict
                logs.append(f"No LoRA matches in this shard. Copying original...")
                api.upload_file(
                    path_or_fileobj=local_shard,
                    path_in_repo=shard_file,
                    repo_id=output_repo,
                    repo_type="model",
                    token=hf_token
                )

            # Cleanup Memory immediately (critical on CPU-only Spaces with limited RAM)
            del base_tensors
            del modified_tensors
            if 'delta' in locals(): del delta  # delta only exists if at least one key merged
            gc.collect()
            os.remove(local_shard)
            if os.path.exists(TempDir / "processed.safetensors"):
                os.remove(TempDir / "processed.safetensors")

        progress(1.0, desc="Done!")
        logs.append(f"\nSUCCESS. Merged {merged_count} layers total.")
        logs.append(f"New model available at: https://huggingface.co/{output_repo}")

    except Exception as e:
        # Surface any failure (with traceback) in the UI log instead of crashing.
        import traceback
        logs.append(f"\nCRITICAL ERROR: {str(e)}")
        logs.append(traceback.format_exc())

    finally:
        # Always clear scratch space, even after an error.
        cleanup_temp()

    return "\n".join(logs)
306
+
307
+
308
# --- UI ---

# Minimal layout tweaks applied to the Gradio Blocks container.
css = """
.container { max-width: 900px; margin: auto; }
.header { text-align: center; margin-bottom: 20px; }
"""

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown(
        """
        # ⚡ Universal LoRA Merger & Reconstructor

        Merge LoRA adapters into **any** base model (LLM, Diffusion, Audio) and reconstruct the repository structure.
        Optimized for CPU-only execution on Hugging Face Spaces.
        """
    )

    # Section 1: credentials plus where the merged model gets pushed.
    with gr.Group():
        gr.Markdown("### 1. Authentication & Output")
        with gr.Row():
            hf_token = gr.Textbox(label="HF Write Token", type="password", placeholder="hf_...")
            output_repo = gr.Textbox(label="Target Output Repo", placeholder="username/Z-Image-Turbo-Custom")
        is_private = gr.Checkbox(label="Private Repo", value=True)

    # Section 2: the base weights the LoRA is merged into.
    with gr.Group():
        gr.Markdown("### 2. Base Weights (The Target)")
        with gr.Row():
            base_repo = gr.Textbox(label="Base Model Repo", placeholder="e.g. ostris/Z-Image-De-Turbo")
            base_subfolder = gr.Textbox(label="Subfolder (Optional)", placeholder="e.g. transformer", info="Only merge weights found inside this folder.")

    # Section 3: the adapter source and merge strength.
    with gr.Group():
        gr.Markdown("### 3. LoRA Configuration")
        with gr.Row():
            lora_input = gr.Textbox(label="LoRA Source", placeholder="Repo ID OR Direct URL (http...)", info="Accepts direct .safetensors resolve links.")
            scale = gr.Slider(label="Scale", minimum=-2.0, maximum=2.0, value=1.0, step=0.1)

    # Section 4: optional repo to copy non-weight files (configs etc.) from.
    with gr.Group():
        gr.Markdown("### 4. Repository Reconstruction (Optional)")
        gr.Markdown("*Use this to fill in missing files (Scheduler, VAE, Tokenizer, model_index.json) from a different source repo.*")
        structure_repo = gr.Textbox(label="Structure Source Repo", placeholder="e.g. Tongyi-MAI/Z-Image-Turbo", info="Copies all NON-weight files from here to output.")

    submit_btn = gr.Button("🚀 Start Merge & Upload", variant="primary")

    output_log = gr.Textbox(label="Process Log", lines=20, interactive=False)

    # Wire the button to the merge pipeline; the log textbox shows its output.
    submit_btn.click(
        fn=run_merge,
        inputs=[hf_token, base_repo, base_subfolder, structure_repo, lora_input, scale, output_repo, is_private],
        outputs=output_log
    )

if __name__ == "__main__":
    # Single-slot queue: merges are RAM-heavy, so only one job runs at a time.
    demo.queue(max_size=1).launch()