Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -239,27 +239,21 @@ def merge_shard_logic(base_path, lora_pairs, scale, output_path):
|
|
| 239 |
valid_delta = False
|
| 240 |
|
| 241 |
if valid_delta:
|
| 242 |
-
#
|
| 243 |
-
#
|
| 244 |
-
# 2. Add delta
|
| 245 |
-
# 3. Cast back to original dtype
|
| 246 |
-
# 4. Replace in dict
|
| 247 |
-
orig_dtype = v.dtype
|
| 248 |
|
| 249 |
-
#
|
| 250 |
-
#
|
| 251 |
-
|
| 252 |
-
|
| 253 |
|
| 254 |
-
#
|
| 255 |
-
|
| 256 |
|
| 257 |
# Explicit cleanup
|
| 258 |
-
del v_float
|
| 259 |
del delta
|
| 260 |
-
# del v # v is a reference to base_state[k], which we just overwrote
|
| 261 |
|
| 262 |
-
# Periodic GC
|
| 263 |
if len(keys_to_process) > 100 and keys_to_process.index(k) % 50 == 0:
|
| 264 |
gc.collect()
|
| 265 |
|
|
@@ -269,6 +263,16 @@ def merge_shard_logic(base_path, lora_pairs, scale, output_path):
|
|
| 269 |
def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, output_repo, structure_repo, private, progress=gr.Progress()):
|
| 270 |
cleanup_temp()
|
| 271 |
login(hf_token)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
|
| 273 |
try:
|
| 274 |
api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
|
|
@@ -288,9 +292,13 @@ def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, output_re
|
|
| 288 |
except Exception as e:
|
| 289 |
print(f"Structure clone warning: {e}")
|
| 290 |
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
|
| 295 |
files = list_repo_files(repo_id=base_repo, token=hf_token)
|
| 296 |
shards = [f for f in files if f.endswith(".safetensors")]
|
|
@@ -303,9 +311,9 @@ def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, output_re
|
|
| 303 |
progress(0.2 + (0.8 * i/len(shards)), desc=f"Merging {shard}")
|
| 304 |
local_shard = hf_hub_download(repo_id=base_repo, filename=shard, token=hf_token, local_dir=TempDir)
|
| 305 |
merged_path = TempDir / "merged.safetensors"
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
merge_shard_logic(local_shard, lora_pairs, scale, merged_path)
|
| 309 |
|
| 310 |
# Upload
|
| 311 |
api.upload_file(path_or_fileobj=merged_path, path_in_repo=shard, repo_id=output_repo, token=hf_token)
|
|
|
|
| 239 |
valid_delta = False
|
| 240 |
|
| 241 |
if valid_delta:
|
| 242 |
+
# Optimized In-Place Addition
|
| 243 |
+
# We do NOT cast base to float32. We trust bf16/fp16 is sufficient for merging.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
+
# If base is float32 (rare for new models), we respect it.
|
| 246 |
+
# If base is bf16, we add bf16 delta.
|
| 247 |
+
if v.dtype != delta.dtype:
|
| 248 |
+
delta = delta.to(v.dtype)
|
| 249 |
|
| 250 |
+
# In-place add
|
| 251 |
+
v.add_(delta)
|
| 252 |
|
| 253 |
# Explicit cleanup
|
|
|
|
| 254 |
del delta
|
|
|
|
| 255 |
|
| 256 |
+
# Periodic GC
|
| 257 |
if len(keys_to_process) > 100 and keys_to_process.index(k) % 50 == 0:
|
| 258 |
gc.collect()
|
| 259 |
|
|
|
|
| 263 |
def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, output_repo, structure_repo, private, progress=gr.Progress()):
|
| 264 |
cleanup_temp()
|
| 265 |
login(hf_token)
|
| 266 |
+
|
| 267 |
+
# Determine Dtype
|
| 268 |
+
if precision == "bf16":
|
| 269 |
+
dtype = torch.bfloat16
|
| 270 |
+
elif precision == "fp16":
|
| 271 |
+
dtype = torch.float16
|
| 272 |
+
else:
|
| 273 |
+
dtype = torch.float32
|
| 274 |
+
|
| 275 |
+
print(f"Selected Precision: {dtype}")
|
| 276 |
|
| 277 |
try:
|
| 278 |
api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
|
|
|
|
| 292 |
except Exception as e:
|
| 293 |
print(f"Structure clone warning: {e}")
|
| 294 |
|
| 295 |
+
try:
|
| 296 |
+
progress(0.1, desc="Downloading LoRA...")
|
| 297 |
+
lora_path = download_file(lora_input, hf_token)
|
| 298 |
+
# Load LoRA in target precision to save RAM immediately
|
| 299 |
+
lora_pairs = load_lora_to_memory(lora_path, precision_dtype=dtype)
|
| 300 |
+
except Exception as e:
|
| 301 |
+
return f"CRITICAL ERROR: {str(e)}"
|
| 302 |
|
| 303 |
files = list_repo_files(repo_id=base_repo, token=hf_token)
|
| 304 |
shards = [f for f in files if f.endswith(".safetensors")]
|
|
|
|
| 311 |
progress(0.2 + (0.8 * i/len(shards)), desc=f"Merging {shard}")
|
| 312 |
local_shard = hf_hub_download(repo_id=base_repo, filename=shard, token=hf_token, local_dir=TempDir)
|
| 313 |
merged_path = TempDir / "merged.safetensors"
|
| 314 |
+
|
| 315 |
+
# Pass precision preference
|
| 316 |
+
merge_shard_logic(local_shard, lora_pairs, scale, merged_path, precision_dtype=dtype)
|
| 317 |
|
| 318 |
# Upload
|
| 319 |
api.upload_file(path_or_fileobj=merged_path, path_in_repo=shard, repo_id=output_repo, token=hf_token)
|