AlekseyCalvin committed
Commit 156e3f3 · verified · 1 Parent(s): f1167d3

Update app.py

Files changed (1): app.py +53 -31
app.py CHANGED
@@ -72,17 +72,33 @@ def cleanup_temp():
     os.makedirs(TempDir, exist_ok=True)
     gc.collect()
 
+def verify_safetensors(path):
+    """Checks if a file is a valid safetensors file."""
+    try:
+        with open(path, "rb") as f:
+            header_size_bytes = f.read(8)
+            if len(header_size_bytes) != 8: return False
+            header_size = struct.unpack("<Q", header_size_bytes)[0]
+            if header_size > os.path.getsize(path) or header_size <= 0:
+                return False
+            return True
+    except:
+        return False
+
 def download_file(input_path, token, filename=None):
     """Downloads a file from URL or HF Repo."""
     local_path = TempDir / (filename if filename else "model.safetensors")
 
     if input_path.startswith("http"):
         print(f"Downloading from URL: {input_path}")
-        response = requests.get(input_path, stream=True)
-        response.raise_for_status()
-        with open(local_path, 'wb') as f:
-            for chunk in response.iter_content(chunk_size=8192):
-                f.write(chunk)
+        try:
+            response = requests.get(input_path, stream=True, timeout=30)
+            response.raise_for_status()
+            with open(local_path, 'wb') as f:
+                for chunk in response.iter_content(chunk_size=8192):
+                    f.write(chunk)
+        except Exception as e:
+            raise ValueError(f"Failed to download URL. Check your link. Error: {e}")
     else:
         print(f"Downloading from Repo: {input_path}")
         if not filename:
@@ -95,22 +111,24 @@ def download_file(input_path, token, filename=None):
                     if "adapter" in f: filename = f
                 else:
                     filename = "adapter_model.bin"
-            except:
+            except Exception as e:
                 filename = "adapter_model.safetensors"
 
-        hf_hub_download(repo_id=input_path, filename=filename, token=token, local_dir=TempDir, local_dir_use_symlinks=False)
-        downloaded_path = TempDir / filename
-        if downloaded_path != local_path:
-            if local_path.exists(): os.remove(local_path)
-            shutil.move(downloaded_path, local_path)
-
+        try:
+            hf_hub_download(repo_id=input_path, filename=filename, token=token, local_dir=TempDir, local_dir_use_symlinks=False)
+            downloaded_path = TempDir / filename
+            if downloaded_path != local_path:
+                if local_path.exists(): os.remove(local_path)
+                shutil.move(downloaded_path, local_path)
+        except Exception as e:
+            raise ValueError(f"Failed to download from HF Repo. Check ID/Token. Error: {e}")
+
+    if not verify_safetensors(local_path):
+        raise ValueError(f"Downloaded file is NOT a valid safetensors file. Check your URL/Repo. (File size: {os.path.getsize(local_path)} bytes)")
+
     return local_path
 
 def get_key_stem(key):
-    """
-    Normalizes a key to its structural stem by removing known prefixes and suffixes.
-    matches 'layers.0.attention' with 'model.diffusion_model.layers.0.attention'.
-    """
     key = key.replace(".weight", "").replace(".bias", "")
     key = key.replace(".lora_down", "").replace(".lora_up", "")
     key = key.replace(".lora_A", "").replace(".lora_B", "")
@@ -135,8 +153,8 @@ def get_key_stem(key):
 # TAB 1: UNIVERSAL MERGE (In-Place Memory Optimization)
 # =================================================================================
 
-def load_lora_to_memory(lora_path):
-    print(f"Loading LoRA from {lora_path}...")
+def load_lora_to_memory(lora_path, precision_dtype=torch.bfloat16):
+    print(f"Loading LoRA from {lora_path} in {precision_dtype}...")
     state_dict = load_file(lora_path, device="cpu")
 
     pairs = {}
@@ -149,11 +167,15 @@ def load_lora_to_memory(lora_path):
         else:
            if stem not in pairs:
                pairs[stem] = {}
+
+           # Cast immediately to save RAM
+           tensor_low = v.to(dtype=precision_dtype)
+
            if "lora_down" in k or "lora_A" in k:
-               pairs[stem]["down"] = v.float()
+               pairs[stem]["down"] = tensor_low
                pairs[stem]["rank"] = v.shape[0]
            elif "lora_up" in k or "lora_B" in k:
-               pairs[stem]["up"] = v.float()
+               pairs[stem]["up"] = tensor_low
 
     for stem in pairs:
         if stem in alphas:
@@ -166,15 +188,15 @@ def load_lora_to_memory(lora_path):
 
     return pairs
 
-def merge_shard_logic(base_path, lora_pairs, scale, output_path):
+def merge_shard_logic(base_path, lora_pairs, scale, output_path, precision_dtype=torch.bfloat16):
     print(f"Loading base shard: {base_path}")
-    # Load base state into RAM. This is the peak memory usage point.
     base_state = load_file(base_path, device="cpu")
 
     lora_keys = set(lora_pairs.keys())
     keys_to_process = list(base_state.keys())
 
     for k in keys_to_process:
+        # Don't detach v yet, we modify in place
         v = base_state[k]
         base_stem = get_key_stem(k)
         match = None
@@ -195,6 +217,7 @@ def merge_shard_logic(base_path, lora_pairs, scale, output_path):
            if qkv_stem in lora_keys: match = lora_pairs[qkv_stem]
 
        if match and "down" in match and "up" in match:
+           # Weights are already in precision_dtype from load_lora_to_memory
            down = match["down"]
            up = match["up"]
            alpha = match["alpha"]
@@ -207,6 +230,7 @@ def merge_shard_logic(base_path, lora_pairs, scale, output_path):
                down = down.unsqueeze(-1).unsqueeze(-1)
                up = up.unsqueeze(-1).unsqueeze(-1)
 
+           # Compute Delta in Low Precision
            try:
                if len(up.shape) == 4:
                    delta = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], 1, 1)
@@ -217,9 +241,9 @@ def merge_shard_logic(base_path, lora_pairs, scale, output_path):
 
                delta = delta * scaling
 
-               # --- Dynamic Reshaping / Slicing ---
                valid_delta = True
 
+               # --- Dynamic Reshaping / Slicing ---
                if delta.shape == v.shape:
                    pass
                elif delta.shape[0] == v.shape[0] * 3:
@@ -260,11 +284,11 @@ def merge_shard_logic(base_path, lora_pairs, scale, output_path):
     save_file(base_state, output_path)
     return True
 
-def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, output_repo, structure_repo, private, progress=gr.Progress()):
+def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, output_repo, structure_repo, private, precision, progress=gr.Progress()):
     cleanup_temp()
     login(hf_token)
-
-    # Determine Dtype
+
+    # Determine Dtype
     if precision == "bf16":
         dtype = torch.bfloat16
     elif precision == "fp16":
@@ -273,7 +297,7 @@ def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, output_re
         dtype = torch.float32
 
     print(f"Selected Precision: {dtype}")
-
+
     try:
         api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
     except Exception as e:
@@ -311,14 +335,12 @@ def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, output_re
         progress(0.2 + (0.8 * i/len(shards)), desc=f"Merging {shard}")
         local_shard = hf_hub_download(repo_id=base_repo, filename=shard, token=hf_token, local_dir=TempDir)
         merged_path = TempDir / "merged.safetensors"
-
-        # Pass precision preference
+
+        # Pass precision preference
         merge_shard_logic(local_shard, lora_pairs, scale, merged_path, precision_dtype=dtype)
 
-        # Upload
         api.upload_file(path_or_fileobj=merged_path, path_in_repo=shard, repo_id=output_repo, token=hf_token)
 
-        # Cleanup immediately
         os.remove(local_shard)
         if merged_path.exists(): os.remove(merged_path)
         gc.collect()