AlekseyCalvin commited on
Commit
5af1d7d
·
verified ·
1 Parent(s): 9281147

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +518 -0
app.py ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ import gc
5
+ import shutil
6
+ import requests
7
+ import json
8
+ import struct
9
+ import numpy as np
10
+ import re
11
+ from pathlib import Path
12
+ from typing import Dict, Any, Optional
13
+ from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
14
+ from safetensors.torch import load_file, save_file
15
+ from tqdm import tqdm
16
+
17
# --- Memory Efficient Safetensors ---
class MemoryEfficientSafeOpen:
    """
    Reads safetensors metadata and tensors without mmap, keeping RAM usage low.
    Essential for running on limited hardware.

    File layout (safetensors spec): an 8-byte little-endian header length,
    a JSON header mapping tensor names to {dtype, shape, data_offsets},
    followed by the raw tensor bytes.
    """

    def __init__(self, filename):
        self.filename = filename
        self.file = open(filename, "rb")
        self.header, self.header_size = self._read_header()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Release the underlying file handle (safe to call more than once).

        Added so callers that do not use the context-manager protocol can
        still avoid leaking the open file.
        """
        if not self.file.closed:
            self.file.close()

    def keys(self) -> list[str]:
        # "__metadata__" is bookkeeping in the header, not a tensor entry.
        return [k for k in self.header.keys() if k != "__metadata__"]

    def metadata(self) -> Dict[str, str]:
        return self.header.get("__metadata__", {})

    def get_tensor(self, key):
        """Read a single tensor's bytes from disk and decode it."""
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")
        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]
        # Data offsets are relative to the end of the header
        # (8-byte size prefix + JSON blob).
        self.file.seek(self.header_size + 8 + offset_start)
        tensor_bytes = self.file.read(offset_end - offset_start)
        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        header_size = struct.unpack("<Q", self.file.read(8))[0]
        header_json = self.file.read(header_size).decode("utf-8")
        return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        dtype_map = {
            "F32": torch.float32, "F16": torch.float16, "BF16": torch.bfloat16,
            "I64": torch.int64, "I32": torch.int32, "I16": torch.int16, "I8": torch.int8,
            "U8": torch.uint8, "BOOL": torch.bool,
        }
        dtype = dtype_map[metadata["dtype"]]
        shape = metadata["shape"]
        # Copy into a writable bytearray: torch.frombuffer on immutable
        # `bytes` emits a UserWarning and yields a non-writable tensor.
        buf = bytearray(tensor_bytes)
        return torch.frombuffer(buf, dtype=torch.uint8).view(dtype).reshape(shape)
63
+
64
# --- Constants & Setup ---
# Scratch directory for all downloads and intermediate artifacts;
# wiped and recreated by cleanup_temp() at the start of every task.
TempDir = Path("./temp_tool")
os.makedirs(TempDir, exist_ok=True)
# Shared Hugging Face Hub client used by every task below.
api = HfApi()
68
+
69
def cleanup_temp():
    """Wipe the scratch directory, recreate it empty, and trigger a GC pass."""
    try:
        shutil.rmtree(TempDir)
    except FileNotFoundError:
        pass
    os.makedirs(TempDir, exist_ok=True)
    gc.collect()
74
+
75
def download_file(input_path, token, filename=None):
    """
    Fetch a single file into the scratch dir and return its local path.

    Args:
        input_path: Either a direct http(s) URL or a Hugging Face repo id.
        token: HF token used for authenticated repo access.
        filename: Target file name inside the repo. When omitted for a repo,
            the first .safetensors file is used (preferring names containing
            "adapter"); "adapter_model.bin" if none exist; and
            "adapter_model.safetensors" if listing the repo fails.

    Returns:
        Path object pointing at the downloaded file inside TempDir.
    """
    local_path = TempDir / (filename if filename else "model.safetensors")

    if input_path.startswith("http"):
        print(f"Downloading from URL: {input_path}")
        # Stream to disk in chunks so large checkpoints never sit fully in
        # RAM, and close the connection deterministically via `with`.
        with requests.get(input_path, stream=True) as response:
            response.raise_for_status()
            with open(local_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
    else:
        print(f"Downloading from Repo: {input_path}")
        if not filename:
            try:
                files = list_repo_files(repo_id=input_path, token=token)
                safetensors = [f for f in files if f.endswith(".safetensors")]
                if safetensors:
                    filename = safetensors[0]
                    # Prefer an adapter file when several safetensors exist.
                    for f in safetensors:
                        if "adapter" in f:
                            filename = f
                else:
                    filename = "adapter_model.bin"
            except Exception:
                # Listing can fail (gated/private repo, transient network
                # error); fall back to the conventional adapter name and let
                # the download itself raise a clearer error if it's wrong.
                filename = "adapter_model.safetensors"

        hf_hub_download(repo_id=input_path, filename=filename, token=token,
                        local_dir=TempDir, local_dir_use_symlinks=False)
        downloaded_path = TempDir / filename
        if downloaded_path != local_path:
            # Normalize to the expected local name so callers can rely on it.
            if local_path.exists():
                os.remove(local_path)
            shutil.move(downloaded_path, local_path)

    return local_path
108
+
109
def get_key_stem(key):
    """
    Normalize a key to its structural stem by stripping known LoRA and
    base-model suffixes/prefixes, so that e.g. 'layers.0.attention' matches
    'model.diffusion_model.layers.0.attention'.
    """
    # Drop the trainable-parameter suffixes first.
    for fragment in (".weight", ".bias", ".lora_down", ".lora_up",
                     ".lora_A", ".lora_B", ".alpha"):
        key = key.replace(fragment, "")

    # Repeatedly peel wrapper prefixes until none apply — they can be
    # nested (e.g. 'base_model.model.' followed by 'model.').
    prefixes = (
        "model.diffusion_model.", "diffusion_model.", "model.",
        "transformer.", "text_encoder.", "lora_unet_", "lora_te_",
        "base_model.model.",
    )
    while True:
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                break
        else:
            return key
133
+
134
+ # =================================================================================
135
+ # TAB 1: UNIVERSAL MERGE (In-Place Memory Optimization)
136
+ # =================================================================================
137
+
138
def load_lora_to_memory(lora_path):
    """
    Load a LoRA file and group its tensors by normalized stem.

    Returns a dict: stem -> {"down": Tensor, "up": Tensor, "rank": int,
    "alpha": float}. Alpha falls back to the rank, then to 1.0, when the
    file carries no explicit alpha for a layer.
    """
    print(f"Loading LoRA from {lora_path}...")
    state_dict = load_file(lora_path, device="cpu")

    pairs = {}
    alphas = {}

    for name, tensor in state_dict.items():
        stem = get_key_stem(name)
        if "alpha" in name:
            alphas[stem] = tensor.item() if isinstance(tensor, torch.Tensor) else tensor
            continue
        entry = pairs.setdefault(stem, {})
        if "lora_down" in name or "lora_A" in name:
            entry["down"] = tensor.float()
            entry["rank"] = tensor.shape[0]
        elif "lora_up" in name or "lora_B" in name:
            entry["up"] = tensor.float()

    # Resolve alpha per pair: explicit value > rank > 1.0.
    for stem, entry in pairs.items():
        if stem in alphas:
            entry["alpha"] = alphas[stem]
        elif "rank" in entry:
            entry["alpha"] = float(entry["rank"])
        else:
            entry["alpha"] = 1.0

    return pairs
168
+
169
def merge_shard_logic(base_path, lora_pairs, scale, output_path):
    """
    Merge LoRA deltas into one base-model shard and save the result.

    Args:
        base_path: Local path of the base .safetensors shard.
        lora_pairs: Mapping stem -> {"down", "up", "rank", "alpha"} as
            produced by load_lora_to_memory().
        scale: User-facing merge strength multiplier.
        output_path: Destination path for the merged shard.

    Returns:
        True on completion. Layers whose shapes cannot be reconciled are
        skipped with a log line, not treated as fatal.
    """
    print(f"Loading base shard: {base_path}")
    # Load base state into RAM. This is the peak memory usage point.
    base_state = load_file(base_path, device="cpu")

    lora_keys = set(lora_pairs.keys())
    keys_to_process = list(base_state.keys())

    # enumerate() replaces the original keys_to_process.index(k), which was
    # an O(n) scan per iteration (quadratic over large shards).
    for idx, k in enumerate(keys_to_process):
        v = base_state[k]
        base_stem = get_key_stem(k)
        match = None

        # 1. Exact match on the normalized stem.
        if base_stem in lora_keys:
            match = lora_pairs[base_stem]
        else:
            # 2. Heuristic match: some bases split attention into separate
            #    to_q/to_k/to_v weights while the LoRA trains one fused qkv
            #    projection (Z-Image).
            for part in ("to_q", "to_k", "to_v"):
                if part in base_stem:
                    qkv_stem = base_stem.replace(part, "qkv")
                    if qkv_stem in lora_keys:
                        match = lora_pairs[qkv_stem]
                    break

        if match and "down" in match and "up" in match:
            down = match["down"]
            up = match["up"]
            alpha = match["alpha"]
            rank = match["rank"]

            # Standard LoRA scaling: user scale * (alpha / rank).
            scaling = scale * (alpha / rank)

            # Handle Conv 1x1 squeeze: promote 2D factors to 4D.
            if len(v.shape) == 4 and len(down.shape) == 2:
                down = down.unsqueeze(-1).unsqueeze(-1)
                up = up.unsqueeze(-1).unsqueeze(-1)

            try:
                if len(up.shape) == 4:
                    delta = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], 1, 1)
                else:
                    delta = up @ down
            except RuntimeError:
                # Shape-mismatch fallback for transposed factors.
                delta = up.T @ down

            delta = delta * scaling

            # --- Dynamic Reshaping / Slicing ---
            valid_delta = True

            if delta.shape == v.shape:
                pass
            elif delta.shape[0] == v.shape[0] * 3:
                # Fused-qkv delta: slice out this projection's third.
                chunk_size = v.shape[0]
                if "to_q" in k:
                    delta = delta[0:chunk_size, ...]
                elif "to_k" in k:
                    delta = delta[chunk_size:2 * chunk_size, ...]
                elif "to_v" in k:
                    delta = delta[2 * chunk_size:, ...]
                else:
                    valid_delta = False
            elif delta.numel() == v.numel():
                delta = delta.reshape(v.shape)
            else:
                print(f"Skipping {k}: Mismatch. Base: {v.shape}, Delta: {delta.shape}")
                valid_delta = False

            if valid_delta:
                # IN-PLACE MERGE to save memory: promote to float32 to avoid
                # overflow/precision issues, add the delta, cast back to the
                # original dtype, and replace the entry in the dict.
                orig_dtype = v.dtype
                v_float = v.to(torch.float32)
                v_float.add_(delta)
                base_state[k] = v_float.to(orig_dtype)

                # Explicit cleanup of the large temporaries.
                del v_float
                del delta

        # Periodic GC to prevent fragmentation OOM on large shards.
        if len(keys_to_process) > 100 and idx % 50 == 0:
            gc.collect()

    save_file(base_state, output_path)
    return True
268
+
269
def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, output_repo, structure_repo, private, progress=gr.Progress()):
    """
    Gradio task: merge a LoRA into every safetensors shard of a base repo
    and upload the results shard-by-shard to the output repo.

    Args mirror the UI widgets; `private` controls output-repo visibility.
    Returns a status string shown in the Result textbox.
    """
    cleanup_temp()
    login(hf_token)

    try:
        api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
    except Exception as e:
        return f"Error creating repo: {e}"

    if structure_repo:
        # Best-effort copy of non-weight files (configs, tokenizer, README…)
        # so the output repo is directly loadable.
        print("Cloning structure...")
        try:
            files = list_repo_files(repo_id=structure_repo, token=hf_token)
            for f in files:
                if not f.endswith(".safetensors") and not f.endswith(".bin"):
                    try:
                        path = hf_hub_download(repo_id=structure_repo, filename=f, token=hf_token)
                        api.upload_file(path_or_fileobj=path, path_in_repo=f, repo_id=output_repo, token=hf_token)
                    except Exception:
                        # Deliberately best-effort: one missing aux file must
                        # not abort the whole merge. (Was a bare `except:`.)
                        pass
        except Exception as e:
            print(f"Structure clone warning: {e}")

    progress(0.1, desc="Loading LoRA...")
    lora_path = download_file(lora_input, hf_token)
    lora_pairs = load_lora_to_memory(lora_path)

    files = list_repo_files(repo_id=base_repo, token=hf_token)
    shards = [f for f in files if f.endswith(".safetensors")]
    if base_subfolder:
        shards = [f for f in shards if f.startswith(base_subfolder)]

    if not shards:
        return "Error: No safetensors found in base."

    for i, shard in enumerate(shards):
        progress(0.2 + (0.8 * i / len(shards)), desc=f"Merging {shard}")
        local_shard = hf_hub_download(repo_id=base_repo, filename=shard, token=hf_token, local_dir=TempDir)
        merged_path = TempDir / "merged.safetensors"

        # Merge this shard against the in-memory LoRA pairs.
        merge_shard_logic(local_shard, lora_pairs, scale, merged_path)

        # Upload under the original shard name so the index stays valid.
        api.upload_file(path_or_fileobj=merged_path, path_in_repo=shard, repo_id=output_repo, token=hf_token)

        # Free disk immediately; shards can be tens of GB.
        os.remove(local_shard)
        if merged_path.exists():
            os.remove(merged_path)
        gc.collect()

    return f"Done! Model at https://huggingface.co/{output_repo}"
319
+
320
+ # =================================================================================
321
+ # TAB 2: EXTRACT LORA
322
+ # =================================================================================
323
+
324
def extract_lora_layer_by_layer(model_org, model_tuned, rank, clamp):
    """
    Extract a LoRA from the difference of two checkpoints via layer-wise SVD.

    Args:
        model_org: Path to the original .safetensors checkpoint.
        model_tuned: Path to the fine-tuned checkpoint.
        rank: Maximum LoRA rank to keep per layer.
        clamp: Quantile (0..1) used to clip extreme factor values.

    Returns:
        Path (str) of the saved extracted-LoRA file.
    """
    lora_sd = {}
    print("Calculating diffs and running SVD (Layer-wise)...")

    # Context managers so both file handles are closed even if SVD raises
    # (the original left them open for the process lifetime).
    with MemoryEfficientSafeOpen(model_org) as org, MemoryEfficientSafeOpen(model_tuned) as tuned:
        # Hoisted: the original rebuilt and linearly scanned tuned.keys()
        # on every iteration.
        tuned_keys = set(tuned.keys())

        for key in tqdm(list(org.keys())):
            if key not in tuned_keys:
                continue
            mat_org = org.get_tensor(key).float()
            mat_tuned = tuned.get_tensor(key).float()

            diff = mat_tuned - mat_org
            if torch.max(torch.abs(diff)) < 1e-4:
                continue
            # 1-D tensors (biases, norms) have no low-rank structure and
            # would crash the two-value shape unpack below — skip them.
            if diff.dim() < 2:
                continue

            out_dim, in_dim = diff.shape[:2]
            r = min(rank, in_dim, out_dim)
            is_conv = len(diff.shape) == 4
            if is_conv:
                diff = diff.flatten(start_dim=1)

            try:
                U, S, Vh = torch.linalg.svd(diff, full_matrices=False)
                U = U[:, :r]
                S = S[:r]
                U = U @ torch.diag(S)
                Vh = Vh[:r, :]

                # Clamp outliers at the requested quantile of all values.
                dist = torch.cat([U.flatten(), Vh.flatten()])
                hi_val = torch.quantile(dist, clamp)
                U = U.clamp(-hi_val, hi_val)
                Vh = Vh.clamp(-hi_val, hi_val)

                if is_conv:
                    U = U.reshape(out_dim, r, 1, 1)
                    Vh = Vh.reshape(r, in_dim, mat_org.shape[2], mat_org.shape[3])
                else:
                    U = U.reshape(out_dim, r)
                    Vh = Vh.reshape(r, in_dim)

                stem = key.replace(".weight", "")
                lora_sd[f"{stem}.lora_up.weight"] = U
                lora_sd[f"{stem}.lora_down.weight"] = Vh
                lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
            except Exception as e:
                # SVD can fail to converge on ill-conditioned diffs; skip the
                # layer rather than abort the whole extraction.
                print(f"SVD failed for {key}: {e}")

    out_path = TempDir / "extracted_lora.safetensors"
    save_file(lora_sd, out_path)
    return str(out_path)
374
+
375
def task_extract(hf_token, org_repo, tuned_repo, rank, output_repo):
    """Gradio task: download two checkpoints, extract a LoRA diff, upload it."""
    cleanup_temp()
    login(hf_token)
    print("Downloading models...")
    original_path = download_file(org_repo, hf_token, "org.safetensors")
    tuned_path = download_file(tuned_repo, hf_token, "tuned.safetensors")
    # Fixed 0.99 clamp quantile; rank comes from a gr.Number so cast to int.
    result_path = extract_lora_layer_by_layer(original_path, tuned_path, int(rank), 0.99)
    api.create_repo(repo_id=output_repo, exist_ok=True, token=hf_token)
    api.upload_file(
        path_or_fileobj=result_path,
        path_in_repo="extracted_lora.safetensors",
        repo_id=output_repo,
        token=hf_token,
    )
    return "Extraction Done."
385
+
386
+ # =================================================================================
387
+ # TAB 3: MERGE ADAPTERS (EMA)
388
+ # =================================================================================
389
+
390
def task_merge_adapters(hf_token, lora_urls, beta, output_repo):
    """
    Gradio task: EMA-merge several LoRA adapters into one and upload it.

    Each subsequent adapter is folded in as
    accumulator = accumulator * beta + adapter * (1 - beta).
    Alpha tensors are left as the first adapter's values.
    """
    cleanup_temp()
    login(hf_token)

    urls = [u.strip() for u in lora_urls.split(",") if u.strip()]
    paths = [download_file(url, hf_token, f"adapter_{i}.safetensors")
             for i, url in enumerate(urls)]

    if not paths:
        return "No models found"

    # First adapter seeds the accumulator; promote floats for stable math.
    base_sd = load_file(paths[0], device="cpu")
    for key in base_sd:
        if base_sd[key].dtype.is_floating_point:
            base_sd[key] = base_sd[key].float()

    # Exponential-moving-average fold over the remaining adapters.
    for path in paths[1:]:
        print(f"Merging {path}")
        curr = load_file(path, device="cpu")
        for key in base_sd:
            if key in curr and "alpha" not in key:
                base_sd[key] = base_sd[key] * beta + curr[key].float() * (1 - beta)

    out = TempDir / "merged_adapters.safetensors"
    save_file(base_sd, out)
    api.create_repo(repo_id=output_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=output_repo, token=hf_token)
    return "Done"
416
+
417
+ # =================================================================================
418
+ # TAB 4: RESIZE
419
+ # =================================================================================
420
+
421
def task_resize(hf_token, lora_input, new_rank, output_repo):
    """
    Gradio task: resize a LoRA to a new rank via SVD of the recomposed delta.

    Fixes vs. the original: `new_rank` is cast to int (gr.Number delivers a
    float, which breaks tensor slicing — task_extract already casts, this
    didn't); the rank is clamped to the matrix dimensions; a dead
    get_key_stem() call is removed.
    """
    cleanup_temp()
    login(hf_token)
    # gr.Number yields a float; tensor slicing/reshape needs an int rank.
    new_rank = int(new_rank)
    path = download_file(lora_input, hf_token)
    state = load_file(path, device="cpu")
    new_state = {}
    print("Resizing...")

    # Group down/up factors by the key text before ".lora_".
    groups = {}
    for k in state:
        stem_simple = k.split(".lora_")[0]
        if stem_simple not in groups:
            groups[stem_simple] = {}
        if "lora_down" in k or "lora_A" in k:
            groups[stem_simple]["down"] = state[k]
        if "lora_up" in k or "lora_B" in k:
            groups[stem_simple]["up"] = state[k]

    for stem, g in tqdm(groups.items()):
        if "down" in g and "up" in g:
            down, up = g["down"].float(), g["up"].float()
            if len(down.shape) == 4:
                # Conv LoRA: recompose, then flatten spatial dims for SVD.
                merged = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3])
                flat = merged.flatten(1)
            else:
                merged = up @ down
                flat = merged

            # Clamp the target rank to what the matrix can provide,
            # otherwise the conv reshape below would fail.
            r = min(new_rank, flat.shape[0], flat.shape[1])

            U, S, Vh = torch.linalg.svd(flat, full_matrices=False)
            U = U[:, :r]
            S = S[:r]
            U = U @ torch.diag(S)
            Vh = Vh[:r, :]

            if len(down.shape) == 4:
                U = U.reshape(up.shape[0], r, 1, 1)
                Vh = Vh.reshape(r, down.shape[1], down.shape[2], down.shape[3])

            new_state[f"{stem}.lora_down.weight"] = Vh
            new_state[f"{stem}.lora_up.weight"] = U
            new_state[f"{stem}.alpha"] = torch.tensor(r).float()

    out = TempDir / "resized.safetensors"
    save_file(new_state, out)
    api.create_repo(repo_id=output_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="resized.safetensors", repo_id=output_repo, token=hf_token)
    return "Done"
466
+
467
+ # =================================================================================
468
+ # UI Construction
469
+ # =================================================================================
470
+
471
css = ".container { max-width: 900px; margin: auto; }"

# NOTE: custom CSS belongs to gr.Blocks(css=...); gr.Blocks.launch() has no
# `css` parameter, so the original launch(css=css, ...) raised a TypeError
# at startup. The stylesheet is now applied where Gradio expects it.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# 🧰 SOONmerge® LoRA Toolkit")

    with gr.Tabs():
        with gr.Tab("Merge (Z-Image Fix)"):
            t1_token = gr.Textbox(label="Token", type="password")
            t1_base = gr.Textbox(label="Base Repo", value="ostris/Z-Image-De-Turbo")
            t1_sub = gr.Textbox(label="Subfolder", value="transformer")
            t1_lora = gr.Textbox(label="LoRA")
            t1_scale = gr.Slider(label="Scale", value=1.0, minimum=-1, maximum=2)
            t1_out = gr.Textbox(label="Output")
            t1_struct = gr.Textbox(label="Structure Repo", value="Tongyi-MAI/Z-Image-Turbo")
            t1_btn = gr.Button("Merge")
            t1_res = gr.Textbox(label="Result")
            # Hidden checkbox supplies the constant `private=True` argument.
            t1_btn.click(task_merge, [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_out, t1_struct, gr.Checkbox(value=True, visible=False)], t1_res)

        with gr.Tab("Extract"):
            t2_token = gr.Textbox(label="Token", type="password")
            t2_org = gr.Textbox(label="Original")
            t2_tun = gr.Textbox(label="Tuned")
            t2_rank = gr.Number(label="Rank", value=32)
            t2_out = gr.Textbox(label="Output")
            t2_btn = gr.Button("Extract")
            t2_res = gr.Textbox(label="Result")
            t2_btn.click(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], t2_res)

        with gr.Tab("Merge Adapters"):
            t3_token = gr.Textbox(label="Token", type="password")
            t3_urls = gr.Textbox(label="URLs (comma sep)")
            t3_beta = gr.Slider(label="Beta", value=0.9)
            t3_out = gr.Textbox(label="Output")
            t3_btn = gr.Button("Merge")
            t3_res = gr.Textbox(label="Result")
            t3_btn.click(task_merge_adapters, [t3_token, t3_urls, t3_beta, t3_out], t3_res)

        with gr.Tab("Resize"):
            t4_token = gr.Textbox(label="Token", type="password")
            t4_in = gr.Textbox(label="LoRA")
            t4_rank = gr.Number(label="Rank", value=8)
            t4_out = gr.Textbox(label="Output")
            t4_btn = gr.Button("Resize")
            t4_res = gr.Textbox(label="Result")
            t4_btn.click(task_resize, [t4_token, t4_in, t4_rank, t4_out], t4_res)

if __name__ == "__main__":
    demo.queue().launch(ssr_mode=False)