AlekseyCalvin committed on
Commit
f1167d3
·
verified ·
1 Parent(s): 5af1d7d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -21
app.py CHANGED
@@ -239,27 +239,21 @@ def merge_shard_logic(base_path, lora_pairs, scale, output_path):
239
  valid_delta = False
240
 
241
  if valid_delta:
242
- # IN-PLACE MERGE to save memory
243
- # 1. Promote to float32
244
- # 2. Add delta
245
- # 3. Cast back to original dtype
246
- # 4. Replace in dict
247
- orig_dtype = v.dtype
248
 
249
- # Perform add in float32 to avoid overflow/precision issues
250
- # Create temp float tensor
251
- v_float = v.to(torch.float32)
252
- v_float.add_(delta) # In-place add
253
 
254
- # Cast back and replace in dict
255
- base_state[k] = v_float.to(orig_dtype)
256
 
257
  # Explicit cleanup
258
- del v_float
259
  del delta
260
- # del v # v is a reference to base_state[k], which we just overwrote
261
 
262
- # Periodic GC to prevent fragmentation OOM
263
  if len(keys_to_process) > 100 and keys_to_process.index(k) % 50 == 0:
264
  gc.collect()
265
 
@@ -269,6 +263,16 @@ def merge_shard_logic(base_path, lora_pairs, scale, output_path):
269
  def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, output_repo, structure_repo, private, progress=gr.Progress()):
270
  cleanup_temp()
271
  login(hf_token)
 
 
 
 
 
 
 
 
 
 
272
 
273
  try:
274
  api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
@@ -288,9 +292,13 @@ def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, output_re
288
  except Exception as e:
289
  print(f"Structure clone warning: {e}")
290
 
291
- progress(0.1, desc="Loading LoRA...")
292
- lora_path = download_file(lora_input, hf_token)
293
- lora_pairs = load_lora_to_memory(lora_path)
 
 
 
 
294
 
295
  files = list_repo_files(repo_id=base_repo, token=hf_token)
296
  shards = [f for f in files if f.endswith(".safetensors")]
@@ -303,9 +311,9 @@ def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, output_re
303
  progress(0.2 + (0.8 * i/len(shards)), desc=f"Merging {shard}")
304
  local_shard = hf_hub_download(repo_id=base_repo, filename=shard, token=hf_token, local_dir=TempDir)
305
  merged_path = TempDir / "merged.safetensors"
306
-
307
- # Merge Logic
308
- merge_shard_logic(local_shard, lora_pairs, scale, merged_path)
309
 
310
  # Upload
311
  api.upload_file(path_or_fileobj=merged_path, path_in_repo=shard, repo_id=output_repo, token=hf_token)
 
239
  valid_delta = False
240
 
241
  if valid_delta:
242
+ # Optimized In-Place Addition
243
+ # We do NOT cast base to float32. We trust bf16/fp16 is sufficient for merging.
 
 
 
 
244
 
245
+ # If base is float32 (rare for new models), we respect it.
246
+ # If base is bf16, we add bf16 delta.
247
+ if v.dtype != delta.dtype:
248
+ delta = delta.to(v.dtype)
249
 
250
+ # In-place add
251
+ v.add_(delta)
252
 
253
  # Explicit cleanup
 
254
  del delta
 
255
 
256
+ # Periodic GC
257
  if len(keys_to_process) > 100 and keys_to_process.index(k) % 50 == 0:
258
  gc.collect()
259
 
 
263
  def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, output_repo, structure_repo, private, progress=gr.Progress()):
264
  cleanup_temp()
265
  login(hf_token)
266
+
267
+ # Determine Dtype
268
+ if precision == "bf16":
269
+ dtype = torch.bfloat16
270
+ elif precision == "fp16":
271
+ dtype = torch.float16
272
+ else:
273
+ dtype = torch.float32
274
+
275
+ print(f"Selected Precision: {dtype}")
276
 
277
  try:
278
  api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
 
292
  except Exception as e:
293
  print(f"Structure clone warning: {e}")
294
 
295
+ try:
296
+ progress(0.1, desc="Downloading LoRA...")
297
+ lora_path = download_file(lora_input, hf_token)
298
+ # Load LoRA in target precision to save RAM immediately
299
+ lora_pairs = load_lora_to_memory(lora_path, precision_dtype=dtype)
300
+ except Exception as e:
301
+ return f"CRITICAL ERROR: {str(e)}"
302
 
303
  files = list_repo_files(repo_id=base_repo, token=hf_token)
304
  shards = [f for f in files if f.endswith(".safetensors")]
 
311
  progress(0.2 + (0.8 * i/len(shards)), desc=f"Merging {shard}")
312
  local_shard = hf_hub_download(repo_id=base_repo, filename=shard, token=hf_token, local_dir=TempDir)
313
  merged_path = TempDir / "merged.safetensors"
314
+
315
+ # Pass precision preference
316
+ merge_shard_logic(local_shard, lora_pairs, scale, merged_path, precision_dtype=dtype)
317
 
318
  # Upload
319
  api.upload_file(path_or_fileobj=merged_path, path_in_repo=shard, repo_id=output_repo, token=hf_token)