AlekseyCalvin committed
Commit d9ef66b · verified · 1 Parent(s): 2db902e

Update app.py

Files changed (1): app.py (+14 -79)
app.py CHANGED
@@ -46,15 +46,12 @@ def download_lora(lora_input, hf_token):
     else:
         # Repo ID download
         print(f"Downloading LoRA from Repo: {lora_input}")
-        # Try finding the safetensors file
         try:
             return hf_hub_download(repo_id=lora_input, filename="adapter_model.safetensors", token=hf_token, local_dir=TempDir)
         except:
-            # Fallback for diffusion models which might use different names
             files = list_repo_files(repo_id=lora_input, token=hf_token)
             safe_files = [f for f in files if f.endswith(".safetensors") and "adapter" in f]
             if not safe_files:
-                # Last ditch: grab the first safetensors
                 safe_files = [f for f in files if f.endswith(".safetensors")]

             if not safe_files:
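
(Aside on the retained fallback above: the bare `except:` catches every failure, including auth and network errors, not just a missing `adapter_model.safetensors`. A narrower variant, sketched here using huggingface_hub's EntryNotFoundError and a hypothetical find_lora_file helper, would keep the fallback while letting real errors surface:

from huggingface_hub import hf_hub_download, list_repo_files
from huggingface_hub.utils import EntryNotFoundError

def find_lora_file(repo_id, token, local_dir):
    # Prefer the PEFT default filename, then fall back to listing the repo.
    try:
        return hf_hub_download(repo_id=repo_id, filename="adapter_model.safetensors",
                               token=token, local_dir=local_dir)
    except EntryNotFoundError:
        files = list_repo_files(repo_id=repo_id, token=token)
        safe = ([f for f in files if f.endswith(".safetensors") and "adapter" in f]
                or [f for f in files if f.endswith(".safetensors")])
        if not safe:
            raise FileNotFoundError(f"No .safetensors file found in {repo_id}")
        return hf_hub_download(repo_id=repo_id, filename=safe[0],
                               token=token, local_dir=local_dir)

This is a sketch only; the committed code keeps the broad except.)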
@@ -63,28 +60,11 @@ def download_lora(lora_input, hf_token):
             return hf_hub_download(repo_id=lora_input, filename=safe_files[0], token=hf_token, local_dir=TempDir)

 def load_lora_weights(path):
-    """Loads LoRA weights and attempts to determine rank/alpha."""
     tensors = load_file(path, device="cpu")
-    # Basic metadata extraction could happen here if needed,
-    # but for raw merging we mainly need the state dict.
     return tensors

 def match_keys(base_key, lora_keys):
-    """
-    Heuristic matching.
-    1. Exact match (rare for LoRA).
-    2. LoRA naming conventions (lora_A, lora_B, lora_down, etc).
-    """
-    # Common LoRA naming patterns
-    # pattern: base_key.lora_A.weight
-    # pattern: base_key + ".0.lora_B.weight" (sometimes happens)
-
     matches = {}
-
-    # Cleaning the keys for comparison
-    # If base is "transformer.blocks.0.weight"
-    # LoRA might be "transformer.blocks.0.lora_A.weight"
-
     candidates = [k for k in lora_keys if base_key in k]

     pair_A = None
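
(For context on the heuristic the removed docstring described: LoRA checkpoints store two low-rank factors per target weight, under PEFT-style suffixes (`lora_A`/`lora_B`) or Kohya-style ones (`lora_down`/`lora_up`). A minimal sketch of the pairing, using a hypothetical pair_lora_keys helper rather than the app's exact code:

def pair_lora_keys(base_key, lora_keys):
    # For base key "blocks.0.attn.qkv.weight", a LoRA checkpoint typically holds
    # "...blocks.0.attn.qkv.lora_A.weight" / "...lora_B.weight" (PEFT) or
    # "...lora_down.weight" / "...lora_up.weight" (Kohya).
    stem = base_key.removesuffix(".weight")
    pair_A = pair_B = None
    for k in lora_keys:
        if stem not in k:
            continue
        if "lora_A" in k or "lora_down" in k:
            pair_A = k
        elif "lora_B" in k or "lora_up" in k:
            pair_B = k
    return pair_A, pair_B
)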
@@ -99,11 +79,9 @@ def match_keys(base_key, lora_keys):
     return pair_A, pair_B

 def copy_auxiliary_files(src_repo, tgt_repo, token, subfolder=""):
-    """Copies config/tokenizer/scheduler files from source to target."""
     print(f"Copying infrastructure from {src_repo} to {tgt_repo}...")
     files = list_repo_files(repo_id=src_repo, token=token)

-    # Filter out heavy weights
     files_to_copy = [
         f for f in files
         if not f.endswith(".safetensors")
@@ -116,7 +94,6 @@ def copy_auxiliary_files(src_repo, tgt_repo, token, subfolder=""):

     for f in tqdm(files_to_copy, desc="Copying configs"):
         try:
-            # We download to memory/temp and upload immediately
             local = hf_hub_download(repo_id=src_repo, filename=f, token=token)
             api.upload_file(
                 path_or_fileobj=local,
@@ -154,9 +131,9 @@ def run_merge(
         except Exception as e:
             return "\n".join(logs) + f"\nError creating repo: {e}"

-        # 2. Replicate Structure (If requested)
+        # 2. Replicate Structure
         if structure_repo.strip():
-            progress(0.1, desc="Cloning Model Structure (Configs)...")
+            progress(0.1, desc="Cloning Model Structure...")
             logs.append(f"Cloning configuration from {structure_repo}...")
             copy_auxiliary_files(structure_repo, output_repo, hf_token)
             logs.append("Configuration files copied.")
@@ -173,25 +150,19 @@ def run_merge(
         progress(0.3, desc="Analyzing Base Model...")
         all_files = list_repo_files(repo_id=base_repo, token=hf_token)

-        # Filter for safetensors in the specific subfolder (if provided)
         target_shards = []
         for f in all_files:
             if not f.endswith(".safetensors"):
                 continue
-
-            # Check subfolder constraint
-            if base_subfolder.strip():
-                # Normalize paths
-                if not f.startswith(base_subfolder.strip("/")):
-                    continue
-
+            if base_subfolder.strip() and not f.startswith(base_subfolder.strip("/")):
+                continue
             target_shards.append(f)

         logs.append(f"Found {len(target_shards)} matching safetensors shards in base.")
         if not target_shards:
             raise ValueError("No safetensors found in the specified base repo/subfolder.")

-        # 5. Process Shards (Streamed)
+        # 5. Process Shards
         total_shards = len(target_shards)
         merged_count = 0

@@ -199,28 +170,16 @@ def run_merge(
             progress(0.3 + (0.6 * (idx / total_shards)), desc=f"Processing Shard {idx+1}/{total_shards}")
             logs.append(f"--- Processing {shard_file} ---")

-            # Download Shard
             local_shard = hf_hub_download(repo_id=base_repo, filename=shard_file, token=hf_token, local_dir=TempDir)
-
-            # Load and Merge
-            # We use safe_open to read metadata, but load_file for the dict to modify
-            # load_file loads to CPU RAM.
             base_tensors = load_file(local_shard, device="cpu")
             modified_tensors = {}
             has_changes = False

             for key, tensor in base_tensors.items():
-                # Match LoRA
-                # Handle architectural prefix mismatches (e.g. Ostris repo might rely on folder structure,
-                # while LoRA expects "transformer." prefix)
-
-                # Try exact match first (unlikely for LoRA)
                 pair_A, pair_B = match_keys(key, lora_keys)

-                # If not found, try adding/removing common prefixes
                 if not pair_A:
-                    # Attempt to match "blocks.1..." to "transformer.blocks.1..."
-                    matches = [k for k in lora_keys if key in k] # Simple substring check
+                    matches = [k for k in lora_keys if key in k]
                     for k in matches:
                         if "lora_A" in k or "lora_down" in k:
                             pair_A = k
@@ -228,24 +187,16 @@ def run_merge(
                             pair_B = k

                 if pair_A and pair_B:
-                    # Apply Merge
                     w_a = lora_state[pair_A].float()
                     w_b = lora_state[pair_B].float()
-
-                    # Target tensor
                     current_tensor = tensor.float()
-
-                    # Dimension Check
-                    # LoRA = B @ A. Shape should match current_tensor.
-                    # Sometimes LoRA weights are transposed relative to base depending on training lib.
                     delta = (w_b @ w_a) * scale

                     if delta.shape != current_tensor.shape:
-                        # Try transposing matches
                         if delta.T.shape == current_tensor.shape:
                             delta = delta.T
                         else:
-                            logs.append(f"Warning: Shape mismatch for {key}. Base: {current_tensor.shape}, LoRA Delta: {delta.shape}. Skipping.")
+                            logs.append(f"Warning: Shape mismatch for {key}. Skipping.")
                             modified_tensors[key] = tensor
                             continue

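(The arithmetic behind the retained `delta` line, for reference: a LoRA factorizes the update to a weight W of shape out x in as B @ A, where A is rank x in and B is out x rank, so the fold-in is W' = W + scale * (B @ A); the transpose fallback covers checkpoints whose factors are stored in the opposite orientation. A self-contained sketch with illustrative shapes only:

import torch

out_dim, in_dim, rank, scale = 64, 32, 4, 0.8
W = torch.randn(out_dim, in_dim)   # base weight (out x in)
A = torch.randn(rank, in_dim)      # lora_A / lora_down (rank x in)
B = torch.randn(out_dim, rank)     # lora_B / lora_up (out x rank)

delta = (B @ A) * scale            # (64, 32), same shape as W
if delta.shape != W.shape and delta.T.shape == W.shape:
    delta = delta.T                # orientation fix, as in the shard loop
W_merged = W + delta
assert W_merged.shape == W.shape
)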
@@ -255,33 +206,16 @@ def run_merge(
                 else:
                     modified_tensors[key] = tensor

-            # Save and Upload
             if has_changes:
                 logs.append(f"Merging complete for shard. Saving...")
                 output_path = TempDir / "processed.safetensors"
                 save_file(modified_tensors, output_path)
-
-                api.upload_file(
-                    path_or_fileobj=output_path,
-                    path_in_repo=shard_file, # Keep original structure
-                    repo_id=output_repo,
-                    repo_type="model",
-                    token=hf_token
-                )
+                api.upload_file(path_or_fileobj=output_path, path_in_repo=shard_file, repo_id=output_repo, repo_type="model", token=hf_token)
                 logs.append(f"Uploaded {shard_file}")
             else:
-                # If no changes, just copy the original file to the new repo
-                # This saves re-saving the tensor dict
                 logs.append(f"No LoRA matches in this shard. Copying original...")
-                api.upload_file(
-                    path_or_fileobj=local_shard,
-                    path_in_repo=shard_file,
-                    repo_id=output_repo,
-                    repo_type="model",
-                    token=hf_token
-                )
+                api.upload_file(path_or_fileobj=local_shard, path_in_repo=shard_file, repo_id=output_repo, repo_type="model", token=hf_token)

-            # Cleanup Memory immediately
             del base_tensors
             del modified_tensors
             if 'delta' in locals(): del delta
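
(Design note on the shard loop above: each safetensors shard is downloaded, merged, and uploaded before the next one is touched, and the tensor dicts are deleted at the end of every iteration, so peak CPU RAM stays near the size of a single shard rather than the whole checkpoint. Shards with no LoRA matches are re-uploaded from the downloaded file instead of being re-serialized, which is what the removed comments described.)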
@@ -301,10 +235,9 @@ def run_merge(

     finally:
         cleanup_temp()
-
+
     return "\n".join(logs)

-
 # --- UI ---

 css = """
@@ -312,7 +245,8 @@ css = """
 .header { text-align: center; margin-bottom: 20px; }
 """

-with gr.Blocks(css=css) as demo:
+# NOTE: Removed 'css' and 'theme' from gr.Blocks() to be compatible with latest Gradio versions.
+with gr.Blocks() as demo:
     gr.Markdown(
         """
         # ⚡ Universal LoRA Merger & Reconstructor
@@ -357,4 +291,5 @@ with gr.Blocks(css=css) as demo:
     )

 if __name__ == "__main__":
-    demo.queue(max_size=1).launch()
+    # CSS is now passed here in the launch method
+    demo.queue(max_size=1).launch(css=css)
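
(One caveat on this final hunk: in the Gradio releases I am aware of, custom CSS is a gr.Blocks constructor argument, and launch() does not document a css parameter, so demo.launch(css=css) may raise a TypeError. A sketch of the documented pattern, built from this file's own UI code:

import gradio as gr

css = """
.header { text-align: center; margin-bottom: 20px; }
"""

with gr.Blocks(css=css) as demo:  # css is a Blocks constructor argument
    gr.Markdown("# ⚡ Universal LoRA Merger & Reconstructor")

if __name__ == "__main__":
    demo.queue(max_size=1).launch()
)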
 
 