AlekseyCalvin committed on
Commit
029e89b
·
verified ·
1 Parent(s): 584e440

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +289 -117
app.py CHANGED
@@ -229,63 +229,177 @@ class ShardBuffer:
229
  self.current_bytes = 0
230
  gc.collect()
231
 
 
 
 
 
232
  def download_lora_smart(input_str, token):
233
- """
234
- Handles Repo IDs (user/repo) and Direct URLs.
235
- """
236
  local_path = TempDir / "adapter.safetensors"
237
-
238
- # 1. Direct URL (Private/Public)
 
239
  if input_str.startswith("http"):
240
  print(f"Downloading LoRA from URL: {input_str}")
241
  headers = {"Authorization": f"Bearer {token}"} if token else {}
242
  try:
243
- response = requests.get(input_str, stream=True, headers=headers, timeout=30)
244
  response.raise_for_status()
245
  with open(local_path, 'wb') as f:
246
  for chunk in response.iter_content(chunk_size=8192):
247
  f.write(chunk)
248
- # Basic validation
249
- with open(local_path, "rb") as f:
250
- if len(f.read(8)) == 8: return local_path
251
  except Exception as e:
252
  print(f"URL download failed: {e}. Trying as Repo ID...")
253
 
254
- # 2. Repo ID (Fallback or Primary)
255
- # If the user entered a repo ID (e.g. "AlekseyCalvin/MyLora"), this catches it.
256
  print(f"Attempting download from Hub Repo: {input_str}")
257
  try:
258
- # Try finding the specific file
 
 
 
 
 
 
259
  candidates = ["adapter_model.safetensors", "model.safetensors"]
260
- target_file = None
261
 
262
- try:
263
- files = list_repo_files(repo_id=input_str, token=token)
264
- safetensors = [f for f in files if f.endswith(".safetensors")]
265
- for c in candidates:
266
- if c in safetensors:
267
- target_file = c
268
- break
269
- if not target_file and safetensors:
270
- target_file = safetensors[0]
271
- except:
272
- # If listing fails, try default
273
- target_file = "adapter_model.safetensors"
274
-
275
- hf_hub_download(repo_id=input_str, filename=target_file, token=token, local_dir=TempDir, local_dir_use_symlinks=False)
276
 
277
- # Rename to generic name
278
- downloaded = TempDir / target_file
 
 
279
  if downloaded != local_path:
280
- if local_path.exists(): os.remove(local_path)
281
  shutil.move(downloaded, local_path)
282
 
283
  return local_path
284
  except Exception as e:
285
- raise ValueError(f"Failed to download LoRA from {input_str}. \nError: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
  def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, precision, shard_size, output_repo, structure_repo, private, progress=gr.Progress()):
288
  cleanup_temp()
 
 
289
  login(hf_token)
290
 
291
  # 1. Output Setup
@@ -293,151 +407,209 @@ def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, precision
293
  api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
294
  except Exception as e: return f"Error creating repo: {e}"
295
 
296
- # Define modes
297
- output_subfolder = base_subfolder if base_subfolder else ""
 
 
 
 
298
 
299
- # 2. Clone Structure
300
- if structure_repo:
301
- print(f"Cloning structure from {structure_repo}...")
302
- # Ignore the folder we are overwriting (if any)
303
- ignore = output_subfolder if output_subfolder else None
304
- # Root merge mode (LLM) usually implies we skip weights in the root
305
- is_root_merge = not bool(output_subfolder)
306
- streaming_copy_structure(hf_token, structure_repo, output_repo, ignore_prefix=ignore, is_root_merge=is_root_merge)
307
-
308
- # 3. Download Input Shards
309
- progress(0.1, desc="Downloading Base Model...")
310
- try:
311
- files = list_repo_files(repo_id=base_repo, token=hf_token)
312
- except Exception as e: return f"Error accessing base repo: {e}"
313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
  input_shards = []
 
315
  for f in files:
316
  if f.endswith(".safetensors"):
317
- # Filter by subfolder if specified
318
- if output_subfolder and not f.startswith(output_subfolder): continue
319
 
320
- local_path = TempDir / "input_shards" / os.path.basename(f)
321
- os.makedirs(local_path.parent, exist_ok=True)
 
322
 
323
- hf_hub_download(repo_id=base_repo, filename=f, token=hf_token, local_dir=local_path.parent, local_dir_use_symlinks=False)
324
-
325
- # Locate file (handle nested download paths)
326
- found = list(local_path.parent.rglob(os.path.basename(f)))
327
  if found: input_shards.append(found[0])
328
 
329
- if not input_shards: return "No base safetensors found in specified location."
330
  input_shards.sort()
331
-
332
- # --- NAMING CONVENTION LOGIC ---
333
- # 1. Check for Diffusers specific subfolders -> force 'diffusion_pytorch_model'
334
- if output_subfolder in ["transformer", "unet"]:
335
- filename_prefix = "diffusion_pytorch_model"
336
- index_filename = "diffusion_pytorch_model.safetensors.index.json"
337
- # 2. Check input file naming -> adopt input convention
338
- elif "diffusion_pytorch_model" in os.path.basename(input_shards[0]):
339
- filename_prefix = "diffusion_pytorch_model"
340
- index_filename = "diffusion_pytorch_model.safetensors.index.json"
341
- # 3. Default to LLM style
342
  else:
343
- filename_prefix = "model"
344
- index_filename = "model.safetensors.index.json"
345
-
346
- print(f"Naming scheme: {filename_prefix} (Index: {index_filename})")
347
 
348
- # 4. Load LoRA
 
 
 
 
 
349
  dtype = torch.bfloat16 if precision == "bf16" else torch.float16 if precision == "fp16" else torch.float32
350
  try:
351
- progress(0.15, desc="Downloading LoRA...")
352
  lora_path = download_lora_smart(lora_input, hf_token)
353
  lora_pairs = load_lora_to_memory(lora_path, precision_dtype=dtype)
354
- except Exception as e: return f"Error loading LoRA: {e}"
355
 
356
- # 5. Stream Process
357
- buffer = ShardBuffer(shard_size, TempDir, output_repo, output_subfolder, hf_token, filename_prefix=filename_prefix)
358
 
359
- for i, shard_file in enumerate(input_shards):
360
- progress(0.2 + (0.7 * i / len(input_shards)), desc=f"Processing {os.path.basename(shard_file)}")
 
361
 
362
- with MemoryEfficientSafeOpen(shard_file) as f:
363
- keys = f.keys()
364
- for k in keys:
365
- v = f.get_tensor(k)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  base_stem = get_key_stem(k)
367
- lora_keys = set(lora_pairs.keys())
368
  match = None
369
 
370
- if base_stem in lora_keys: match = lora_pairs[base_stem]
371
- # QKV Heuristics (Z-Image/Flux specific)
 
372
  if not match:
373
  if "to_q" in base_stem:
374
- qkv_stem = base_stem.replace("to_q", "qkv")
375
- if qkv_stem in lora_keys: match = lora_pairs[qkv_stem]
376
  elif "to_k" in base_stem:
377
- qkv_stem = base_stem.replace("to_k", "qkv")
378
- if qkv_stem in lora_keys: match = lora_pairs[qkv_stem]
379
  elif "to_v" in base_stem:
380
- qkv_stem = base_stem.replace("to_v", "qkv")
381
- if qkv_stem in lora_keys: match = lora_pairs[qkv_stem]
382
 
383
- if match and "down" in match and "up" in match:
384
  down = match["down"]
385
  up = match["up"]
 
386
  scaling = scale * (match["alpha"] / match["rank"])
387
-
388
  if len(v.shape) == 4 and len(down.shape) == 2:
389
  down = down.unsqueeze(-1).unsqueeze(-1)
390
  up = up.unsqueeze(-1).unsqueeze(-1)
391
-
392
  try:
393
  if len(up.shape) == 4:
394
  delta = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], 1, 1)
395
  else:
396
  delta = up @ down
397
- except:
398
- delta = up.T @ down
399
 
400
  delta = delta * scaling
401
- valid_delta = True
402
 
 
 
403
  if delta.shape == v.shape: pass
404
  elif delta.shape[0] == v.shape[0] * 3:
405
  chunk = v.shape[0]
406
  if "to_q" in k: delta = delta[0:chunk, ...]
407
  elif "to_k" in k: delta = delta[chunk:2*chunk, ...]
408
  elif "to_v" in k: delta = delta[2*chunk:, ...]
409
- else: valid_delta = False
410
- elif delta.numel() == v.numel():
411
- delta = delta.reshape(v.shape)
412
- else: valid_delta = False
413
-
414
- if valid_delta:
415
  v = v.to(dtype)
416
  delta = delta.to(dtype)
417
  v.add_(delta)
418
  del delta
419
-
 
 
420
  if v.dtype != dtype: v = v.to(dtype)
421
- buffer.add_tensor(k, v)
 
 
 
 
422
  del v
423
-
424
- os.remove(shard_file)
 
 
 
 
 
 
425
  gc.collect()
426
 
427
- buffer.flush()
428
-
429
- # 6. Upload Index (Now using correct total_size)
430
- print(f"Uploading Index: {index_filename} (Total Size: {buffer.total_size})")
431
- index_data = {"metadata": {"total_size": buffer.total_size}, "weight_map": buffer.index_map}
432
-
433
- with open(TempDir / index_filename, "w") as f:
 
 
 
 
434
  json.dump(index_data, f, indent=4)
435
 
436
- path_in_repo = f"{output_subfolder}/{index_filename}" if output_subfolder else index_filename
437
- api.upload_file(path_or_fileobj=TempDir / index_filename, path_in_repo=path_in_repo, repo_id=output_repo, token=hf_token)
438
 
439
  cleanup_temp()
440
- return f"Done! Merged into {buffer.shard_count} shards at {output_repo}"
441
 
442
  # =================================================================================
443
  # TAB 2: EXTRACT LORA
 
229
  self.current_bytes = 0
230
  gc.collect()
231
 
232
+ # =================================================================================
233
+ # ROBUST RESHARDING LOGIC (Plan -> Execute)
234
+ # =================================================================================
235
+
236
  def download_lora_smart(input_str, token):
237
+ """Robust LoRA downloader that handles Direct URLs and Repo IDs."""
 
 
238
  local_path = TempDir / "adapter.safetensors"
239
+ if local_path.exists(): os.remove(local_path)
240
+
241
+ # 1. Try as Direct URL
242
  if input_str.startswith("http"):
243
  print(f"Downloading LoRA from URL: {input_str}")
244
  headers = {"Authorization": f"Bearer {token}"} if token else {}
245
  try:
246
+ response = requests.get(input_str, stream=True, headers=headers, timeout=60)
247
  response.raise_for_status()
248
  with open(local_path, 'wb') as f:
249
  for chunk in response.iter_content(chunk_size=8192):
250
  f.write(chunk)
251
+ if verify_safetensors(local_path): return local_path
 
 
252
  except Exception as e:
253
  print(f"URL download failed: {e}. Trying as Repo ID...")
254
 
255
+ # 2. Try as Repo ID
 
256
  print(f"Attempting download from Hub Repo: {input_str}")
257
  try:
258
+ # Check if user provided a filename in the repo string (e.g. user/repo/file.safetensors)
259
+ if ".safetensors" in input_str and "/" in input_str:
260
+ # splitting repo_id and filename might be needed, but hf_hub_download expects valid repo_id
261
+ pass
262
+
263
+ # Try to find the adapter file automatically
264
+ files = list_repo_files(repo_id=input_str, token=token)
265
  candidates = ["adapter_model.safetensors", "model.safetensors"]
266
+ target = next((f for f in files if f in candidates), None)
267
 
268
+ # If no standard name, take the first safetensors found
269
+ if not target:
270
+ safes = [f for f in files if f.endswith(".safetensors")]
271
+ if safes: target = safes[0]
272
+
273
+ if not target: raise ValueError("No .safetensors found")
 
 
 
 
 
 
 
 
274
 
275
+ hf_hub_download(repo_id=input_str, filename=target, token=token, local_dir=TempDir)
276
+
277
+ # Move to standard location
278
+ downloaded = TempDir / target
279
  if downloaded != local_path:
 
280
  shutil.move(downloaded, local_path)
281
 
282
  return local_path
283
  except Exception as e:
284
+ raise ValueError(f"Could not download LoRA. Checked URL and Repo. Error: {e}")
285
+
286
+ def get_tensor_byte_size(shape, dtype_str):
287
+ """Calculates byte size of a tensor based on shape and dtype."""
288
+ # F32=4, F16/BF16=2, I8=1, etc.
289
+ bytes_per = 4 if "F32" in dtype_str else 2 if "16" in dtype_str else 1
290
+ numel = 1
291
+ for d in shape: numel *= d
292
+ return numel * bytes_per
293
+
294
+ def plan_resharding(input_shards, max_shard_size_gb, filename_prefix):
295
+ """
296
+ Pass 1: Reads headers ONLY. Groups tensors into virtual shards of max_shard_size_gb.
297
+ Returns a Plan (List of ShardDefinitions).
298
+ """
299
+ print(f"Planning resharding (Max {max_shard_size_gb} GB)...")
300
+ max_bytes = int(max_shard_size_gb * 1024**3)
301
+
302
+ all_tensors = []
303
+
304
+ # 1. Scan all inputs
305
+ for p in input_shards:
306
+ with MemoryEfficientSafeOpen(p) as f:
307
+ for k in f.keys():
308
+ shape = f.header[k]['shape']
309
+ dtype = f.header[k]['dtype']
310
+ size = get_tensor_byte_size(shape, dtype)
311
+ all_tensors.append({
312
+ "key": k,
313
+ "shape": shape,
314
+ "dtype": dtype,
315
+ "size": size,
316
+ "source": p
317
+ })
318
+
319
+ # 2. Sort tensors (Crucial for deterministic output)
320
+ all_tensors.sort(key=lambda x: x["key"])
321
+
322
+ # 3. Bucket into Shards
323
+ plan = []
324
+ current_shard = []
325
+ current_size = 0
326
+
327
+ for t in all_tensors:
328
+ # If adding this tensor exceeds limit AND we have stuff in the bucket, close bucket
329
+ if current_size + t['size'] > max_bytes and current_shard:
330
+ plan.append(current_shard)
331
+ current_shard = []
332
+ current_size = 0
333
+
334
+ current_shard.append(t)
335
+ current_size += t['size']
336
+
337
+ if current_shard:
338
+ plan.append(current_shard)
339
+
340
+ total_shards = len(plan)
341
+ total_model_size = sum(t['size'] for shard in plan for t in shard)
342
+
343
+ print(f"Plan created: {total_shards} shards. Total size: {total_model_size / 1024**3:.2f} GB")
344
+
345
+ # 4. Format Plan
346
+ final_plan = []
347
+ for i, shard_tensors in enumerate(plan):
348
+ # Naming: prefix-00001-of-00005.safetensors
349
+ name = f"{filename_prefix}-{i+1:05d}-of-{total_shards:05d}.safetensors"
350
+ final_plan.append({
351
+ "filename": name,
352
+ "tensors": shard_tensors
353
+ })
354
+
355
+ return final_plan, total_model_size
356
+
357
+ def copy_auxiliary_configs(hf_token, base_repo, base_subfolder, output_repo, output_subfolder):
358
+ """
359
+ Downloads NON-WEIGHT files (json, txt, model) from Base Repo and uploads to Output.
360
+ """
361
+ print(f"Copying config files from {base_repo}...")
362
+ try:
363
+ files = list_repo_files(repo_id=base_repo, token=hf_token)
364
+
365
+ # Extensions to KEEP (Configs, Tokenizers, etc.)
366
+ allowed_ext = ['.json', '.txt', '.model', '.py', '.yml', '.yaml']
367
+ # Extensions to SKIP (Weights, we are generating these)
368
+ blocked_ext = ['.safetensors', '.bin', '.pt', '.pth', '.msgpack', '.h5']
369
+
370
+ for f in files:
371
+ # Filter by subfolder if needed
372
+ if base_subfolder and not f.startswith(base_subfolder):
373
+ continue
374
+
375
+ ext = os.path.splitext(f)[1]
376
+ if ext in blocked_ext: continue
377
+ if ext not in allowed_ext: continue # Skip unknown types to be safe? Or allow?
378
+
379
+ # Download
380
+ print(f"Transferring {f}...")
381
+ local = hf_hub_download(repo_id=base_repo, filename=f, token=hf_token, local_dir=TempDir)
382
+
383
+ # Determine path in new repo
384
+ if base_subfolder:
385
+ # Remove base_subfolder prefix for the rel path
386
+ rel_name = f[len(base_subfolder):].lstrip('/')
387
+ else:
388
+ rel_name = f
389
+
390
+ # Add output_subfolder prefix
391
+ target_path = f"{output_subfolder}/{rel_name}" if output_subfolder else rel_name
392
+
393
+ api.upload_file(path_or_fileobj=local, path_in_repo=target_path, repo_id=output_repo, token=hf_token)
394
+ os.remove(local)
395
+
396
+ except Exception as e:
397
+ print(f"Config copy warning: {e}")
398
 
399
  def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, precision, shard_size, output_repo, structure_repo, private, progress=gr.Progress()):
400
  cleanup_temp()
401
+
402
+ if not hf_token: return "Error: Token missing."
403
  login(hf_token)
404
 
405
  # 1. Output Setup
 
407
  api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
408
  except Exception as e: return f"Error creating repo: {e}"
409
 
410
+ # Determine Folder Logic
411
+ # If base_subfolder is "qint4", and we want output to be "transformer", user needs to specify that.
412
+ # But usually, if base has a subfolder, we maintain a subfolder structure.
413
+ # ADAPTIVE: If base_subfolder is "qint4", we treat it as the source of weights.
414
+ # Since you merged into "transformer", I assume you want the output in "transformer".
415
+ # For general LLMs (root), both are empty.
416
 
417
+ # Heuristic: If base has subfolder, use "transformer" as target if it looks like a DiT, else keep original name.
418
+ if base_subfolder:
419
+ output_subfolder = "transformer" if "qint" in base_subfolder or "transformer" in base_subfolder else base_subfolder
420
+ else:
421
+ output_subfolder = ""
 
 
 
 
 
 
 
 
 
422
 
423
+ # 2. Copy Configs (The missing step from previous run)
424
+ copy_auxiliary_configs(hf_token, base_repo, base_subfolder, output_repo, output_subfolder)
425
+
426
+ # 3. Structure Repo (Only needed if Base doesn't have everything, e.g. VAE)
427
+ if structure_repo:
428
+ print(f"Copying extras from {structure_repo}...")
429
+ # We assume structure repo is a standard diffusers repo
430
+ # We copy text_encoder, vae, scheduler, tokenizer, etc.
431
+ # We SKIP 'transformer' or 'unet' because we are building that.
432
+ streaming_copy_structure(hf_token, structure_repo, output_repo, ignore_prefix="transformer")
433
+
434
+ # 4. Download ALL Input Shards (Needed for Planning)
435
+ progress(0.1, desc="Downloading Input Model...")
436
+ files = list_repo_files(repo_id=base_repo, token=hf_token)
437
  input_shards = []
438
+
439
  for f in files:
440
  if f.endswith(".safetensors"):
441
+ if base_subfolder and not f.startswith(base_subfolder): continue
 
442
 
443
+ local = TempDir / "inputs" / os.path.basename(f)
444
+ os.makedirs(local.parent, exist_ok=True)
445
+ hf_hub_download(repo_id=base_repo, filename=f, token=hf_token, local_dir=local.parent, local_dir_use_symlinks=False)
446
 
447
+ # Handle nesting
448
+ found = list(local.parent.rglob(os.path.basename(f)))
 
 
449
  if found: input_shards.append(found[0])
450
 
451
+ if not input_shards: return "No safetensors found."
452
  input_shards.sort()
453
+
454
+ # 5. Detect Naming Convention (Adaptive)
455
+ sample_name = os.path.basename(input_shards[0])
456
+ if "diffusion_pytorch_model" in sample_name or output_subfolder == "transformer":
457
+ prefix = "diffusion_pytorch_model"
458
+ index_file = "diffusion_pytorch_model.safetensors.index.json"
 
 
 
 
 
459
  else:
460
+ prefix = "model"
461
+ index_file = "model.safetensors.index.json"
 
 
462
 
463
+ # 6. Create Plan (Pass 1)
464
+ # This calculates total shards and size BEFORE processing
465
+ progress(0.2, desc="Planning Shards...")
466
+ plan, total_model_size = plan_resharding(input_shards, shard_size, prefix)
467
+
468
+ # 7. Load LoRA
469
  dtype = torch.bfloat16 if precision == "bf16" else torch.float16 if precision == "fp16" else torch.float32
470
  try:
471
+ progress(0.25, desc="Loading LoRA...")
472
  lora_path = download_lora_smart(lora_input, hf_token)
473
  lora_pairs = load_lora_to_memory(lora_path, precision_dtype=dtype)
474
+ except Exception as e: return f"LoRA Error: {e}"
475
 
476
+ # 8. Execute Plan (Pass 2)
477
+ index_map = {}
478
 
479
+ for i, shard_plan in enumerate(plan):
480
+ filename = shard_plan['filename']
481
+ tensors_to_write = shard_plan['tensors']
482
 
483
+ progress(0.3 + (0.7 * i / len(plan)), desc=f"Merging {filename}")
484
+ print(f"Generating {filename} ({len(tensors_to_write)} tensors)...")
485
+
486
+ # Prepare Header
487
+ header = {"__metadata__": {"format": "pt"}}
488
+ current_offset = 0
489
+ for t in tensors_to_write:
490
+ # Recalculate dtype string for header based on TARGET dtype
491
+ tgt_dtype_str = "BF16" if dtype == torch.bfloat16 else "F16" if dtype == torch.float16 else "F32"
492
+
493
+ # Calculate output size (might differ from input size if we change precision)
494
+ # Input size in plan was source size. We need target size.
495
+ out_size = get_tensor_byte_size(t['shape'], tgt_dtype_str)
496
+
497
+ header[t['key']] = {
498
+ "dtype": tgt_dtype_str,
499
+ "shape": t['shape'],
500
+ "data_offsets": [current_offset, current_offset + out_size]
501
+ }
502
+ current_offset += out_size
503
+ index_map[t['key']] = filename
504
+
505
+ header_json = json.dumps(header).encode('utf-8')
506
+
507
+ out_path = TempDir / filename
508
+ with open(out_path, 'wb') as f_out:
509
+ f_out.write(struct.pack('<Q', len(header_json)))
510
+ f_out.write(header_json)
511
+
512
+ # Open source files as needed
513
+ open_files = {}
514
+
515
+ for t_plan in tqdm(tensors_to_write, leave=False):
516
+ src = t_plan['source']
517
+ if src not in open_files: open_files[src] = MemoryEfficientSafeOpen(src)
518
+
519
+ # Load Tensor
520
+ v = open_files[src].get_tensor(t_plan['key'])
521
+ k = t_plan['key']
522
+
523
+ # --- MERGE LOGIC ---
524
  base_stem = get_key_stem(k)
 
525
  match = None
526
 
527
+ # Check match (Same logic as before)
528
+ if base_stem in lora_pairs: match = lora_pairs[base_stem]
529
+ # ... [QKV Logic omitted for brevity, same as previous] ...
530
  if not match:
531
  if "to_q" in base_stem:
532
+ qkv = base_stem.replace("to_q", "qkv")
533
+ if qkv in lora_pairs: match = lora_pairs[qkv]
534
  elif "to_k" in base_stem:
535
+ qkv = base_stem.replace("to_k", "qkv")
536
+ if qkv in lora_pairs: match = lora_pairs[qkv]
537
  elif "to_v" in base_stem:
538
+ qkv = base_stem.replace("to_v", "qkv")
539
+ if qkv in lora_pairs: match = lora_pairs[qkv]
540
 
541
+ if match:
542
  down = match["down"]
543
  up = match["up"]
544
+ # ... [Matmul Logic, same as previous] ...
545
  scaling = scale * (match["alpha"] / match["rank"])
 
546
  if len(v.shape) == 4 and len(down.shape) == 2:
547
  down = down.unsqueeze(-1).unsqueeze(-1)
548
  up = up.unsqueeze(-1).unsqueeze(-1)
 
549
  try:
550
  if len(up.shape) == 4:
551
  delta = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], 1, 1)
552
  else:
553
  delta = up @ down
554
+ except: delta = up.T @ down
 
555
 
556
  delta = delta * scaling
 
557
 
558
+ # Slicing
559
+ valid = True
560
  if delta.shape == v.shape: pass
561
  elif delta.shape[0] == v.shape[0] * 3:
562
  chunk = v.shape[0]
563
  if "to_q" in k: delta = delta[0:chunk, ...]
564
  elif "to_k" in k: delta = delta[chunk:2*chunk, ...]
565
  elif "to_v" in k: delta = delta[2*chunk:, ...]
566
+ else: valid = False
567
+ elif delta.numel() == v.numel(): delta = delta.reshape(v.shape)
568
+ else: valid = False
569
+
570
+ if valid:
 
571
  v = v.to(dtype)
572
  delta = delta.to(dtype)
573
  v.add_(delta)
574
  del delta
575
+ # --- END MERGE ---
576
+
577
+ # Write
578
  if v.dtype != dtype: v = v.to(dtype)
579
+ if dtype == torch.bfloat16:
580
+ raw = v.view(torch.int16).numpy().tobytes()
581
+ else:
582
+ raw = v.numpy().tobytes()
583
+ f_out.write(raw)
584
  del v
585
+
586
+ # Close handles
587
+ for fh in open_files.values(): fh.file.close()
588
+
589
+ # Upload Shard
590
+ path_in_repo = f"{output_subfolder}/{filename}" if output_subfolder else filename
591
+ api.upload_file(path_or_fileobj=out_path, path_in_repo=path_in_repo, repo_id=output_repo, token=hf_token)
592
+ os.remove(out_path)
593
  gc.collect()
594
 
595
+ # 9. Upload Index
596
+ # Update total size to reflect the TARGET dtype size, not source
597
+ # We recalculate total_size based on what we actually wrote
598
+ final_total_size = 0
599
+ for t_list in plan:
600
+ for t in t_list['tensors']:
601
+ tgt_dtype_str = "BF16" if dtype == torch.bfloat16 else "F16" if dtype == torch.float16 else "F32"
602
+ final_total_size += get_tensor_byte_size(t['shape'], tgt_dtype_str)
603
+
604
+ index_data = {"metadata": {"total_size": final_total_size}, "weight_map": index_map}
605
+ with open(TempDir / index_file, "w") as f:
606
  json.dump(index_data, f, indent=4)
607
 
608
+ path_in_repo = f"{output_subfolder}/{index_file}" if output_subfolder else index_file
609
+ api.upload_file(path_or_fileobj=TempDir / index_file, path_in_repo=path_in_repo, repo_id=output_repo, token=hf_token)
610
 
611
  cleanup_temp()
612
+ return f"Success! {len(plan)} shards created at {output_repo}"
613
 
614
  # =================================================================================
615
  # TAB 2: EXTRACT LORA