AlekseyCalvin committed
Commit 89c201f · verified · 1 Parent(s): 3d46573

Create app.py

Files changed (1):
  app.py +533 -0
app.py ADDED
import gradio as gr
import torch
import os
import gc
import re
import shutil
import requests
import json
import numpy as np
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
from safetensors.torch import load_file, save_file
from tqdm import tqdm

# --- Constants & Setup ---
TempDir = Path("./temp_tool")
os.makedirs(TempDir, exist_ok=True)
api = HfApi()

def cleanup_temp():
    if TempDir.exists():
        shutil.rmtree(TempDir)
    os.makedirs(TempDir, exist_ok=True)
    gc.collect()

# --- Utility Functions ---

def download_file(input_path, token, filename=None):
    """Downloads a file from a URL or an HF repo."""
    local_path = TempDir / (filename if filename else "model.safetensors")

    if input_path.startswith("http"):
        print(f"Downloading from URL: {input_path}")
        # Pass the token along so gated/private resolve URLs also work.
        headers = {"Authorization": f"Bearer {token}"} if token else {}
        response = requests.get(input_path, stream=True, headers=headers)
        response.raise_for_status()
        with open(local_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
    else:
        print(f"Downloading from Repo: {input_path}")
        if not filename:
            try:
                files = list_repo_files(repo_id=input_path, token=token)
                safetensors = [f for f in files if f.endswith(".safetensors")]
                filename = safetensors[0] if safetensors else "adapter_model.bin"
            except Exception:
                filename = "adapter_model.safetensors"

        hf_hub_download(repo_id=input_path, filename=filename, token=token, local_dir=TempDir, local_dir_use_symlinks=False)
        downloaded_path = TempDir / filename
        if downloaded_path != local_path:
            shutil.move(downloaded_path, local_path)

    return local_path

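# Usage sketch (illustrative values, not from the original file): the same helper
# accepts either a direct resolve URL or a bare repo id, e.g.
#   download_file("https://huggingface.co/user/repo/resolve/main/lora.safetensors", token)
#   download_file("user/repo", token)  # auto-picks the first .safetensors in the repo
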
def get_key_stem(key):
    """
    Normalizes a key to its structural stem.
    Aggressively strips known prefixes to align Comfy/Kohya/Diffusers keys.
    """
    # 1. Remove suffixes
    key = key.replace(".weight", "").replace(".bias", "")
    key = key.replace(".lora_down", "").replace(".lora_up", "")
    key = key.replace(".lora_A", "").replace(".lora_B", "")
    key = key.replace(".alpha", "")

    # 2. Remove common prefixes (repeat until no prefix matches)
    prefixes = [
        "model.diffusion_model.", "diffusion_model.", "model.",
        "transformer.", "text_encoder.", "lora_unet_", "lora_te_"
    ]

    changed = True
    while changed:
        changed = False
        for p in prefixes:
            if key.startswith(p):
                key = key[len(p):]
                changed = True
    return key

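# Illustrative stem normalization (hypothetical keys, not from any real checkpoint):
#   get_key_stem("lora_unet_blocks_0_attn_qkv.lora_down.weight")  -> "blocks_0_attn_qkv"
#   get_key_stem("transformer.blocks.0.attn.to_q.lora_A.weight")  -> "blocks.0.attn.to_q"
#   get_key_stem("diffusion_model.blocks.0.attn.to_q.weight")     -> "blocks.0.attn.to_q"
# so a Diffusers-style LoRA key and a Comfy-style base key for the same tensor
# land on the same stem.
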
# =================================================================================
# TAB 1: SMART MERGE (Fixes Z-Image QKV)
# =================================================================================

def load_lora_to_memory(lora_path):
    """Loads a LoRA and pre-computes its up/down pairs."""
    state_dict = load_file(lora_path, device="cpu")
    alphas = {}
    weights = {}

    for k, v in state_dict.items():
        if "alpha" in k:
            stem = get_key_stem(k)
            alphas[stem] = v.item() if isinstance(v, torch.Tensor) else v
        else:
            weights[k] = v

    pairs = {}

    for k, v in weights.items():
        stem = get_key_stem(k)
        if stem not in pairs:
            pairs[stem] = {}

        if "lora_down" in k or "lora_A" in k:
            pairs[stem]["down"] = v.float()
            pairs[stem]["rank"] = v.shape[0]
        elif "lora_up" in k or "lora_B" in k:
            pairs[stem]["up"] = v.float()

    for stem in pairs:
        if stem in alphas:
            pairs[stem]["alpha"] = alphas[stem]
        elif "rank" in pairs[stem]:
            # No stored alpha: alpha == rank is the convention for a scaling factor of 1.0.
            pairs[stem]["alpha"] = float(pairs[stem]["rank"])
        else:
            pairs[stem]["alpha"] = 1.0

    return pairs

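# Each entry in `pairs` ends up shaped like (illustrative stem name):
#   pairs["blocks.0.attn.to_q"] = {"down": A (r x in), "up": B (out x r),
#                                  "rank": r, "alpha": a}
# and contributes a weight delta of  scale * (a / r) * (B @ A)  during merging.
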
def merge_shard_logic(base_path, lora_pairs, scale, output_path):
    base_state = load_file(base_path, device="cpu")
    modified_state = {}
    has_modifications = False

    # Pre-index LoRA stems for fast lookup
    lora_stems = set(lora_pairs.keys())

    for k, v in base_state.items():
        base_stem = get_key_stem(k)

        # 1. Direct match
        match = lora_pairs.get(base_stem)

        # 2. QKV match (the Z-Image fix): the base keeps split `to_q`/`to_k`/`to_v`
        #    projections, but the LoRA was trained against a fused `qkv` weight.
        chunk_idx = -1
        if not match:
            for idx, proj in enumerate(("to_q", "to_k", "to_v")):
                if proj in base_stem:
                    qkv_stem = base_stem.replace(proj, "qkv")
                    if qkv_stem in lora_stems:
                        match = lora_pairs[qkv_stem]
                        chunk_idx = idx
                    break

        if match and "down" in match and "up" in match:
            down = match["down"]
            up = match["up"]

            # Handle Conv2d 1x1
            if len(v.shape) == 4 and len(down.shape) == 2:
                down = down.unsqueeze(-1).unsqueeze(-1)
                up = up.unsqueeze(-1).unsqueeze(-1)

            scaling = scale * (match["alpha"] / match["rank"])

            try:
                # Standard LoRA matmul (up @ down)
                if len(up.shape) == 4:
                    delta = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], 1, 1)  # 1x1 conv
                else:
                    delta = up @ down
            except RuntimeError:
                delta = up.T @ down  # Fallback for transposed weights

            delta = delta * scaling

            # --- QKV chunking: the LoRA delta covers Q+K+V stacked along dim 0, so slice it ---
            if chunk_idx >= 0:
                chunk_size = delta.shape[0] // 3
                start = chunk_idx * chunk_size
                delta = delta[start:start + chunk_size, ...]

            # Final shape check
            if delta.shape != v.shape:
                if delta.numel() == v.numel():
                    delta = delta.reshape(v.shape)
                else:
                    print(f"Skipping {k}: Shape mismatch Base {v.shape} vs Delta {delta.shape}")
                    modified_state[k] = v
                    continue

            modified_state[k] = (v.float() + delta).to(v.dtype)
            has_modifications = True
        else:
            modified_state[k] = v

    if has_modifications:
        save_file(modified_state, output_path)
        return True
    return False

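# QKV slicing sketch (illustrative shapes): with hidden size d, the fused delta is
# (3*d, in) and chunk_idx selects the projection's rows:
#   delta[0*d : 1*d]  -> to_q   (chunk_idx 0)
#   delta[1*d : 2*d]  -> to_k   (chunk_idx 1)
#   delta[2*d : 3*d]  -> to_v   (chunk_idx 2)
# This assumes the fused weight stacks Q, K, V in that order along dim 0.
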
def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, output_repo, structure_repo, private, progress=gr.Progress()):
    cleanup_temp()
    login(hf_token)

    try:
        api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
    except Exception as e:
        return f"Error creating repo: {e}"

    if structure_repo:
        print("Cloning structure...")
        try:
            files = list_repo_files(repo_id=structure_repo, token=hf_token)
            for f in files:
                if not f.endswith(".safetensors") and not f.endswith(".bin"):
                    try:
                        path = hf_hub_download(repo_id=structure_repo, filename=f, token=hf_token)
                        api.upload_file(path_or_fileobj=path, path_in_repo=f, repo_id=output_repo, token=hf_token)
                    except Exception:
                        pass
        except Exception as e:
            print(f"Structure clone warning: {e}")

    progress(0.1, desc="Loading LoRA...")
    lora_path = download_file(lora_input, hf_token)
    lora_pairs = load_lora_to_memory(lora_path)
    print(f"Loaded LoRA with {len(lora_pairs)} modules.")

    files = list_repo_files(repo_id=base_repo, token=hf_token)
    shards = [f for f in files if f.endswith(".safetensors")]
    if base_subfolder:
        shards = [f for f in shards if f.startswith(base_subfolder)]

    if not shards:
        return "Error: No model shards found in base repo."

    for i, shard in enumerate(shards):
        progress(0.2 + (0.8 * i / len(shards)), desc=f"Merging {shard}")
        print(f"Processing {shard}...")
        local_shard = hf_hub_download(repo_id=base_repo, filename=shard, token=hf_token, local_dir=TempDir)

        merged_path = TempDir / "merged.safetensors"
        success = merge_shard_logic(local_shard, lora_pairs, scale, merged_path)

        # Upload, preserving the repo's directory structure.
        api.upload_file(path_or_fileobj=merged_path if success else local_shard, path_in_repo=shard, repo_id=output_repo, token=hf_token)

        os.remove(local_shard)
        if merged_path.exists():
            os.remove(merged_path)
        gc.collect()

    return f"Done! Model at https://huggingface.co/{output_repo}"

# =================================================================================
# TAB 2: EXTRACT LORA
# =================================================================================

def extract_lora(model_org, model_tuned, rank, conv_rank, clamp):
    try:
        org_state = load_file(model_org, device="cpu")
        tuned_state = load_file(model_tuned, device="cpu")
    except Exception:
        return None, "Error: Could not load models."

    lora_sd = {}
    print("Calculating diffs and running SVD...")

    for key in tqdm(org_state.keys()):
        if key not in tuned_state:
            continue

        # Calculate diff
        mat = tuned_state[key].float() - org_state[key].float()
        if torch.max(torch.abs(mat)) < 1e-4:
            continue

        # Only Linear (2D) and Conv (4D) weights can be decomposed.
        if mat.dim() not in (2, 4):
            continue

        out_dim, in_dim = mat.shape[:2]
        rank_to_use = min(rank, in_dim, out_dim)

        is_conv = mat.dim() == 4
        kernel_size = mat.shape[2:] if is_conv else None
        if is_conv:
            mat = mat.flatten(start_dim=1)

        try:
            # SVD
            U, S, Vh = torch.linalg.svd(mat, full_matrices=False)
            U = U[:, :rank_to_use]
            S = S[:rank_to_use]
            U = U @ torch.diag(S)
            Vh = Vh[:rank_to_use, :]

            # Clamp outliers (Kohya trick)
            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(dist, clamp)
            low_val = -hi_val
            U = U.clamp(low_val, hi_val)
            Vh = Vh.clamp(low_val, hi_val)

            # Reshape (kernel dims were captured before the flatten above)
            if is_conv:
                U = U.reshape(out_dim, rank_to_use, 1, 1)
                Vh = Vh.reshape(rank_to_use, in_dim, *kernel_size)
            else:
                U = U.reshape(out_dim, rank_to_use)
                Vh = Vh.reshape(rank_to_use, in_dim)

            stem = key.replace(".weight", "")
            lora_sd[f"{stem}.lora_up.weight"] = U
            lora_sd[f"{stem}.lora_down.weight"] = Vh
            lora_sd[f"{stem}.alpha"] = torch.tensor(rank_to_use).float()

        except Exception as e:
            print(f"SVD failed for {key}: {e}")

    out_path = TempDir / "extracted_lora.safetensors"
    save_file(lora_sd, out_path)
    return str(out_path), "Success"

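# Low-rank sketch of what the loop stores (illustrative): with diff D = W_tuned - W_org
# and truncated SVD D ≈ U_r Σ_r V_rᵀ, the saved pair is
#   lora_up   = U_r Σ_r   (out_dim x rank)
#   lora_down = V_rᵀ      (rank x in_dim)
# so lora_up @ lora_down reproduces the best rank-r approximation of D
# (up to the quantile clamp applied above).
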
def task_extract(hf_token, org_repo, tuned_repo, rank, output_repo):
    cleanup_temp()
    login(hf_token)
    print("Downloading Original...")
    org_path = download_file(org_repo, hf_token, "original.safetensors")
    print("Downloading Tuned...")
    tuned_path = download_file(tuned_repo, hf_token, "tuned.safetensors")

    path, msg = extract_lora(org_path, tuned_path, int(rank), int(rank), 0.99)

    if path:
        api.create_repo(repo_id=output_repo, exist_ok=True, token=hf_token)
        api.upload_file(path_or_fileobj=path, path_in_repo="extracted_lora.safetensors", repo_id=output_repo, token=hf_token)
        return "Extraction Done."
    return msg

# =================================================================================
# TAB 3: MERGE ADAPTERS (Post-Hoc EMA)
# =================================================================================

def merge_adapters_ema(lora_paths, beta, output_path):
    """
    Merges checkpoints with an exponential moving average, in the spirit of
    post-hoc EMA (lora_post_hoc_ema.py). Files are merged in the order given,
    which is assumed to be temporal order.
    """
    if not lora_paths:
        return False

    print(f"Loading base: {lora_paths[0]}")
    base_state = load_file(lora_paths[0], device="cpu")

    # Convert to float32 for merging
    for k in base_state:
        if base_state[k].dtype.is_floating_point:
            base_state[k] = base_state[k].float()

    for path in lora_paths[1:]:
        print(f"Merging {path}...")
        current_state = load_file(path, device="cpu")

        # Simple fixed-beta EMA: state = state * beta + new * (1 - beta).
        # (Kohya's script also supports a dynamic power-function beta via
        # sigma_rel; here we use the user-provided constant.)
        for k in base_state:
            if k in current_state:
                if "alpha" in k:
                    continue  # Alphas should match across checkpoints
                base_state[k] = base_state[k] * beta + current_state[k].float() * (1 - beta)

    save_file(base_state, output_path)
    return True

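# EMA weighting sketch: with constant beta, unrolling the recurrence over
# checkpoints c_0..c_n gives
#   result = beta^n * c_0 + (1 - beta) * sum_{i=1..n} beta^(n-i) * c_i
# i.e. later checkpoints dominate for small beta, earlier ones for beta near 1.
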
def task_merge_adapters(hf_token, lora_urls, beta, output_repo):
    cleanup_temp()
    login(hf_token)

    urls = [url.strip() for url in lora_urls.split(",")]
    local_paths = []

    for i, url in enumerate(urls):
        if not url:
            continue
        print(f"Downloading Adapter {i+1}...")
        # Handles both resolve URLs and repo ids.
        path = download_file(url, hf_token, f"adapter_{i}.safetensors")
        local_paths.append(path)

    out_path = TempDir / "merged_adapters.safetensors"
    success = merge_adapters_ema(local_paths, beta, out_path)

    if success:
        api.create_repo(repo_id=output_repo, exist_ok=True, token=hf_token)
        api.upload_file(path_or_fileobj=out_path, path_in_repo="merged_adapters_ema.safetensors", repo_id=output_repo, token=hf_token)
        return "Adapter Merge Done."
    return "Error merging adapters."

# =================================================================================
# TAB 4: RESIZE LORA
# =================================================================================

def task_resize(hf_token, lora_input, new_rank, output_repo):
    cleanup_temp()
    login(hf_token)
    new_rank = int(new_rank)  # gr.Number yields a float; tensor slicing needs an int

    path = download_file(lora_input, hf_token)
    state = load_file(path, device="cpu")
    new_state = {}

    print("Resizing...")
    stems = {get_key_stem(k) for k in state.keys()}

    for stem in tqdm(stems):
        down_key = None
        up_key = None

        # Fuzzy finder for the raw keys
        for k in state:
            if stem in k and ("lora_down" in k or "lora_A" in k):
                down_key = k
            if stem in k and ("lora_up" in k or "lora_B" in k):
                up_key = k

        if down_key and up_key:
            down = state[down_key].float()
            up = state[up_key].float()

            if len(down.shape) == 2:
                merged = up @ down
            else:
                merged = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], 1, 1)

            # Re-SVD at the target rank
            U, S, Vh = torch.linalg.svd(merged.flatten(1), full_matrices=False)
            U = U[:, :new_rank]
            S = S[:new_rank]
            U = U @ torch.diag(S)
            Vh = Vh[:new_rank, :]

            # Restore conv shapes if the original pair was convolutional.
            if state[down_key].dim() == 4:
                U = U.reshape(U.shape[0], new_rank, 1, 1)
                Vh = Vh.reshape(new_rank, *state[down_key].shape[1:])

            new_state[down_key] = Vh
            new_state[up_key] = U
            # Rewrite the matching alpha key
            for k in state:
                if stem in k and "alpha" in k:
                    new_state[k] = torch.tensor(float(new_rank))

    out = TempDir / "resized.safetensors"
    save_file(new_state, out)

    api.create_repo(repo_id=output_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="resized_lora.safetensors", repo_id=output_repo, token=hf_token)
    return "Resize Done."

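# Resize sketch: the pair is first fused back into the full delta
#   D = up @ down            # (out x old_rank) @ (old_rank x in)
# and re-factorized at the smaller rank via truncated SVD,
#   D ≈ (U_r Σ_r) @ V_rᵀ     # new up (out x new_rank), new down (new_rank x in)
# which, by Eckart-Young, is the best new_rank approximation of D in Frobenius norm.
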
# =================================================================================
# UI
# =================================================================================

css = """
.container { max-width: 900px; margin: auto; }
"""

with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧰 SOONmerge® Toolkit")
    gr.Markdown("Includes: Smart QKV Un-fusing, Post-Hoc EMA, Adapter Merging, Resizing, and Extraction.")

    with gr.Tabs():
        # --- TAB 1 ---
        with gr.Tab("Merge LoRA into Base"):
            gr.Markdown("Supports merging Z-Image fused-QKV LoRAs into a split-projection base.")
            t1_token = gr.Textbox(label="HF Token", type="password")
            with gr.Row():
                t1_base = gr.Textbox(label="Base Model Repo", placeholder="ostris/Z-Image-De-Turbo")
                t1_sub = gr.Textbox(label="Subfolder (Optional)", placeholder="transformer")
            with gr.Row():
                t1_lora = gr.Textbox(label="LoRA Repo/URL")
                t1_scale = gr.Slider(label="Scale", value=1.0, minimum=-1, maximum=2)
            t1_out = gr.Textbox(label="Output Repo")
            t1_struct = gr.Textbox(label="Structure Repo (Optional)", placeholder="Tongyi-MAI/Z-Image-Turbo")
            t1_private = gr.Checkbox(value=True, visible=False)  # hidden: output repos are private by default
            t1_btn = gr.Button("Merge")
            t1_log = gr.Textbox(label="Log", interactive=False)

            t1_btn.click(task_merge, [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_out, t1_struct, t1_private], t1_log)

        # --- TAB 2 ---
        with gr.Tab("Extract LoRA"):
            t2_token = gr.Textbox(label="HF Token", type="password")
            t2_org = gr.Textbox(label="Original Model Repo/URL")
            t2_tuned = gr.Textbox(label="Tuned Model Repo/URL")
            t2_rank = gr.Number(label="Rank", value=32)
            t2_out = gr.Textbox(label="Output Repo")
            t2_btn = gr.Button("Extract")
            t2_log = gr.Textbox(label="Log")

            t2_btn.click(task_extract, [t2_token, t2_org, t2_tuned, t2_rank, t2_out], t2_log)

        # --- TAB 3 ---
        with gr.Tab("Merge Adapters (EMA)"):
            gr.Markdown("Post-Hoc EMA Merge: combines multiple LoRAs into one file.")
            t3_token = gr.Textbox(label="HF Token", type="password")
            t3_urls = gr.Textbox(label="LoRA URLs (comma separated)", placeholder="http://...lora1.safetensors, http://...lora2.safetensors")
            t3_beta = gr.Slider(label="Beta (Decay)", value=0.95, minimum=0.0, maximum=1.0)
            t3_out = gr.Textbox(label="Output Repo")
            t3_btn = gr.Button("Merge Adapters")
            t3_log = gr.Textbox(label="Log")

            t3_btn.click(task_merge_adapters, [t3_token, t3_urls, t3_beta, t3_out], t3_log)

        # --- TAB 4 ---
        with gr.Tab("Resize LoRA"):
            t4_token = gr.Textbox(label="HF Token", type="password")
            t4_in = gr.Textbox(label="LoRA Repo/URL")
            t4_rank = gr.Number(label="Target Rank", value=8)
            t4_out = gr.Textbox(label="Output Repo")
            t4_btn = gr.Button("Resize")
            t4_log = gr.Textbox(label="Log")

            t4_btn.click(task_resize, [t4_token, t4_in, t4_rank, t4_out], t4_log)

if __name__ == "__main__":
    demo.queue().launch()