AlekseyCalvin committed on
Commit 4419bdb · verified · 1 Parent(s): a690cfc

Update app.py

Files changed (1):
  app.py (+122 -100)

app.py CHANGED
@@ -1,3 +1,5 @@
+# MERGE APP EDIT:
+
 import gradio as gr
 import torch
 import os
@@ -154,14 +156,15 @@ class ShardBuffer:
         self.output_repo = output_repo
         self.subfolder = subfolder
         self.hf_token = hf_token
-        self.filename_prefix = filename_prefix  # Dynamic prefix (e.g. 'diffusion_pytorch_model' or 'model')
+        self.filename_prefix = filename_prefix
         self.buffer = []
         self.current_bytes = 0
         self.shard_count = 0
         self.index_map = {}
-        self.total_model_size = 0
+        self.total_size = 0  # Accumulates total model size for index.json
 
     def add_tensor(self, key, tensor):
+        # Determine bytes for size calculation and storage
         if tensor.dtype == torch.bfloat16:
             raw_bytes = tensor.view(torch.int16).numpy().tobytes()
             dtype_str = "BF16"
@@ -173,14 +176,16 @@ class ShardBuffer:
             dtype_str = "F32"
 
         size = len(raw_bytes)
+
         self.buffer.append({
             "key": key,
             "data": raw_bytes,
             "dtype": dtype_str,
             "shape": tensor.shape
         })
+
         self.current_bytes += size
-        self.total_model_size += size
+        self.total_size += size  # Explicitly increment total size
 
         if self.current_bytes >= self.max_bytes:
             self.flush()
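
Note: NumPy has no bfloat16 dtype, which is why `add_tensor` serializes BF16 by reinterpreting the tensor as int16 first. A standalone sketch (not part of app.py) showing that this view round-trips the bytes losslessly:

    import torch

    t = torch.randn(4, 4, dtype=torch.bfloat16)

    # Serialize: reinterpret the 2-byte bf16 payload as int16, then dump raw bytes.
    raw = t.view(torch.int16).numpy().tobytes()

    # Deserialize: rebuild an int16 buffer and reinterpret it back as bf16.
    restored = torch.frombuffer(bytearray(raw), dtype=torch.int16).view(torch.bfloat16).reshape(t.shape)
    assert torch.equal(t, restored)
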
@@ -189,7 +194,8 @@ class ShardBuffer:
         if not self.buffer: return
         self.shard_count += 1
 
-        # ADAPTIVE NAMING: Uses the prefix detected from the base model
+        # Naming: prefix-0000X.safetensors
+        # This is standard for indexed loading.
         filename = f"{self.filename_prefix}-{self.shard_count:05d}.safetensors"
 
         # Proper Subfolder Handling
@@ -206,7 +212,7 @@ class ShardBuffer:
                 "data_offsets": [current_offset, current_offset + len(item["data"])]
             }
             current_offset += len(item["data"])
-            self.index_map[item["key"]] = filename
+            self.index_map[item["key"]] = filename  # Relative filename for index
 
         header_json = json.dumps(header).encode('utf-8')
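
Note: `flush()` hand-writes the safetensors container format: an 8-byte little-endian header length, a JSON header mapping each tensor key to dtype/shape/data_offsets, then the concatenated raw tensor bytes. A self-contained sketch of that layout (hypothetical helper, assuming the same header fields as the hunk above):

    import json, struct

    def write_safetensors_sketch(path, tensors):
        # tensors: {key: (dtype_str, shape, raw_bytes)}
        header, blobs, offset = {}, [], 0
        for key, (dtype_str, shape, data) in tensors.items():
            header[key] = {
                "dtype": dtype_str,
                "shape": list(shape),
                "data_offsets": [offset, offset + len(data)],
            }
            blobs.append(data)
            offset += len(data)
        header_json = json.dumps(header).encode("utf-8")
        with open(path, "wb") as f:
            f.write(struct.pack("<Q", len(header_json)))  # 8-byte LE header size
            f.write(header_json)
            for blob in blobs:
                f.write(blob)
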
@@ -225,44 +231,60 @@ class ShardBuffer:
         self.current_bytes = 0
         gc.collect()
 
-def streaming_copy_structure(token, src_repo, dst_repo, ignore_prefix="transformer"):
+def download_lora_smart(input_str, token):
     """
-    Copies files one-by-one from source to dest, skipping 'ignore_prefix'.
-    Does NOT skip .safetensors/.bin if they are outside the ignore folder.
+    Handles Repo IDs (user/repo) and Direct URLs.
     """
-    print(f"Scanning {src_repo} for auxiliary files...")
+    local_path = TempDir / "adapter.safetensors"
+
+    # 1. Direct URL (Private/Public)
+    if input_str.startswith("http"):
+        print(f"Downloading LoRA from URL: {input_str}")
+        headers = {"Authorization": f"Bearer {token}"} if token else {}
+        try:
+            response = requests.get(input_str, stream=True, headers=headers, timeout=30)
+            response.raise_for_status()
+            with open(local_path, 'wb') as f:
+                for chunk in response.iter_content(chunk_size=8192):
+                    f.write(chunk)
+            # Basic validation: a real safetensors file starts with an 8-byte header size
+            with open(local_path, "rb") as f:
+                if len(f.read(8)) == 8: return local_path
+        except Exception as e:
+            print(f"URL download failed: {e}. Trying as Repo ID...")
+
+    # 2. Repo ID (Fallback or Primary)
+    # If the user entered a repo ID (e.g. "AlekseyCalvin/MyLora"), this catches it.
+    print(f"Attempting download from Hub Repo: {input_str}")
     try:
-        files = api.list_repo_files(repo_id=src_repo, token=token)
-
-        for f in tqdm(files, desc="Copying Structure"):
-            # 1. Skip the folder we are replacing (e.g., transformer/)
-            if ignore_prefix and f.startswith(ignore_prefix):
-                continue
-
-            # 2. Skip hidden/system files
-            if f.startswith("."):
-                continue
-
-            # 3. Download -> Upload -> Delete loop
-            # This ensures we get VAE/TextEnc weights without disk overflow
-            try:
-                print(f"Copying {f}...")
-                local = hf_hub_download(repo_id=src_repo, filename=f, token=token, local_dir=TempDir)
-
-                api.upload_file(
-                    path_or_fileobj=local,
-                    path_in_repo=f,
-                    repo_id=dst_repo,
-                    token=token
-                )
-
-                if os.path.exists(local):
-                    os.remove(local)
-            except Exception as e:
-                print(f"Failed to copy {f}: {e}")
+        # Prefer common adapter filenames, else fall back to any .safetensors
+        candidates = ["adapter_model.safetensors", "model.safetensors"]
+        target_file = None
+
+        try:
+            files = list_repo_files(repo_id=input_str, token=token)
+            safetensors = [f for f in files if f.endswith(".safetensors")]
+            for c in candidates:
+                if c in safetensors:
+                    target_file = c
+                    break
+            if not target_file and safetensors:
+                target_file = safetensors[0]
+        except Exception:
+            # If listing fails, try the default adapter name
+            target_file = "adapter_model.safetensors"
+
+        hf_hub_download(repo_id=input_str, filename=target_file, token=token, local_dir=TempDir, local_dir_use_symlinks=False)
+
+        # Rename to the generic name expected downstream
+        downloaded = TempDir / target_file
+        if downloaded != local_path:
+            if local_path.exists(): os.remove(local_path)
+            shutil.move(downloaded, local_path)
+
+        return local_path
     except Exception as e:
-        print(f"Structure cloning error: {e}")
+        raise ValueError(f"Failed to download LoRA from {input_str}. \nError: {e}")
 
 def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, precision, shard_size, output_repo, structure_repo, private, progress=gr.Progress()):
     cleanup_temp()
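
Note: the new downloader accepts either a direct URL or a Hub repo ID, and it relies on `requests` and `shutil`, which are assumed to be imported elsewhere in app.py. Illustrative calls (repo names and token are placeholders):

    lora_path = download_lora_smart("https://huggingface.co/some-user/some-lora/resolve/main/adapter.safetensors", token="hf_...")
    lora_path = download_lora_smart("some-user/some-lora", token="hf_...")  # falls back to listing the repo
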
@@ -273,62 +295,73 @@ def task_merge(...)
         api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
     except Exception as e: return f"Error creating repo: {e}"
 
-    # 2. Server-Side Structure Clone
+    # Define modes
+    output_subfolder = base_subfolder if base_subfolder else ""
+
+    # 2. Clone Structure
     if structure_repo:
-        ignore = base_subfolder if base_subfolder else None
-        streaming_copy_structure(hf_token, structure_repo, output_repo, ignore)
-
-    # 3. Load LoRA
-    dtype = torch.bfloat16 if precision == "bf16" else torch.float16 if precision == "fp16" else torch.float32
+        print(f"Cloning structure from {structure_repo}...")
+        # Ignore the folder we are overwriting (if any)
+        ignore = output_subfolder if output_subfolder else None
+        # Root merge mode (LLM) usually implies we skip weights in the root
+        is_root_merge = not bool(output_subfolder)
+        streaming_copy_structure(hf_token, structure_repo, output_repo, ignore_prefix=ignore, is_root_merge=is_root_merge)
+
+    # 3. Download Input Shards
+    progress(0.1, desc="Downloading Base Model...")
     try:
-        progress(0.1, desc="Downloading LoRA...")
-        lora_path = download_file(lora_input, hf_token, filename="adapter.safetensors")
-        lora_pairs = load_lora_to_memory(lora_path, precision_dtype=dtype)
-    except Exception as e: return f"Error loading LoRA: {e}"
-
-    # 4. Stream Process
-    progress(0.2, desc="Fetching File List...")
-    files = list_repo_files(repo_id=base_repo, token=hf_token)
-
-    # Identify valid shards in the target folder
+        files = list_repo_files(repo_id=base_repo, token=hf_token)
+    except Exception as e: return f"Error accessing base repo: {e}"
+
     input_shards = []
     for f in files:
-        if not f.endswith(".safetensors"): continue
-        if base_subfolder and not f.startswith(base_subfolder): continue
-        input_shards.append(f)
+        if f.endswith(".safetensors"):
+            # Filter by subfolder if specified
+            if output_subfolder and not f.startswith(output_subfolder): continue
+
+            local_path = TempDir / "input_shards" / os.path.basename(f)
+            os.makedirs(local_path.parent, exist_ok=True)
+
+            hf_hub_download(repo_id=base_repo, filename=f, token=hf_token, local_dir=local_path.parent, local_dir_use_symlinks=False)
+
+            # Locate the file (handle nested download paths)
+            found = list(local_path.parent.rglob(os.path.basename(f)))
+            if found: input_shards.append(found[0])
 
     if not input_shards: return "No base safetensors found in specified location."
-
     input_shards.sort()
 
-    # --- AUTO-DETECT NAMING CONVENTION ---
-    # We look at the first file to decide the naming scheme.
-    # Common schemes:
-    #   "diffusion_pytorch_model-00001..." -> prefix: "diffusion_pytorch_model"
-    #   "model-00001..."                   -> prefix: "model"
-    #   "model.safetensors"                -> prefix: "model"
-
-    first_file = os.path.basename(input_shards[0])
-
-    if first_file.startswith("diffusion_pytorch_model"):
+    # --- NAMING CONVENTION LOGIC ---
+    # 1. Diffusers-specific subfolders -> force 'diffusion_pytorch_model'
+    if output_subfolder in ["transformer", "unet"]:
         filename_prefix = "diffusion_pytorch_model"
         index_filename = "diffusion_pytorch_model.safetensors.index.json"
+    # 2. Otherwise adopt the input files' naming convention
+    elif "diffusion_pytorch_model" in os.path.basename(input_shards[0]):
+        filename_prefix = "diffusion_pytorch_model"
+        index_filename = "diffusion_pytorch_model.safetensors.index.json"
+    # 3. Default to LLM style
     else:
-        # Default for LLMs, Text Encoders, etc.
         filename_prefix = "model"
         index_filename = "model.safetensors.index.json"
 
-    print(f"Detected naming convention: {filename_prefix} (Index: {index_filename})")
+    print(f"Naming scheme: {filename_prefix} (Index: {index_filename})")
 
-    # Initialize Buffer with detected prefix
-    buffer = ShardBuffer(shard_size, TempDir, output_repo, base_subfolder, hf_token, filename_prefix=filename_prefix)
+    # 4. Load LoRA
+    dtype = torch.bfloat16 if precision == "bf16" else torch.float16 if precision == "fp16" else torch.float32
+    try:
+        progress(0.15, desc="Downloading LoRA...")
+        lora_path = download_lora_smart(lora_input, hf_token)
+        lora_pairs = load_lora_to_memory(lora_path, precision_dtype=dtype)
+    except Exception as e: return f"Error loading LoRA: {e}"
+
+    # 5. Stream Process
+    buffer = ShardBuffer(shard_size, TempDir, output_repo, output_subfolder, hf_token, filename_prefix=filename_prefix)
 
     for i, shard_file in enumerate(input_shards):
-        progress(0.2 + (0.7 * i / len(input_shards)), desc=f"Processing {shard_file}")
-
-        local_shard = hf_hub_download(repo_id=base_repo, filename=shard_file, token=hf_token, local_dir=TempDir)
+        progress(0.2 + (0.7 * i / len(input_shards)), desc=f"Processing {os.path.basename(shard_file)}")
 
-        with MemoryEfficientSafeOpen(local_shard) as f:
+        with MemoryEfficientSafeOpen(shard_file) as f:
             keys = f.keys()
             for k in keys:
                 v = f.get_tensor(k)
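
Note: the per-tensor merge in the hunks below is the standard LoRA rule, W' = W + scale * (alpha / rank) * (up @ down). A toy sketch with illustrative shapes:

    import torch

    rank, out_dim, in_dim = 16, 64, 32
    W = torch.randn(out_dim, in_dim)   # base weight ("v" in the loop)
    down = torch.randn(rank, in_dim)   # LoRA down projection (lora_A)
    up = torch.randn(out_dim, rank)    # LoRA up projection (lora_B)
    alpha, scale = 16.0, 1.0

    delta = (up @ down) * (scale * (alpha / rank))
    merged = W + delta
    assert merged.shape == W.shape
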
@@ -336,28 +369,24 @@ def task_merge(...)
                 lora_keys = set(lora_pairs.keys())
                 match = None
 
-                # Matching Logic (Exact + Heuristic for QKV)
-                if base_stem in lora_keys:
-                    match = lora_pairs[base_stem]
-                else:
+                if base_stem in lora_keys: match = lora_pairs[base_stem]
+                # QKV Heuristics (Z-Image/Flux specific)
+                if not match:
                     if "to_q" in base_stem:
                         qkv_stem = base_stem.replace("to_q", "qkv")
                         if qkv_stem in lora_keys: match = lora_pairs[qkv_stem]
                     elif "to_k" in base_stem:
                         qkv_stem = base_stem.replace("to_k", "qkv")
                         if qkv_stem in lora_keys: match = lora_pairs[qkv_stem]
                     elif "to_v" in base_stem:
                         qkv_stem = base_stem.replace("to_v", "qkv")
                         if qkv_stem in lora_keys: match = lora_pairs[qkv_stem]
 
                 if match and "down" in match and "up" in match:
                     down = match["down"]
                     up = match["up"]
-                    alpha = match["alpha"]
-                    rank = match["rank"]
-                    scaling = scale * (alpha / rank)
+                    scaling = scale * (match["alpha"] / match["rank"])
 
-                    # Handle Conv 1x1 squeeze
                     if len(v.shape) == 4 and len(down.shape) == 2:
                         down = down.unsqueeze(-1).unsqueeze(-1)
                         up = up.unsqueeze(-1).unsqueeze(-1)
@@ -373,9 +402,7 @@ def task_merge(...)
                     delta = delta * scaling
                     valid_delta = True
 
-                    # Shape Slicing Logic
-                    if delta.shape == v.shape:
-                        pass
+                    if delta.shape == v.shape: pass
                     elif delta.shape[0] == v.shape[0] * 3:
                         chunk = v.shape[0]
                         if "to_q" in k: delta = delta[0:chunk, ...]
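
Note: the `v.shape[0] * 3` branch handles LoRAs trained against a fused qkv projection, whose delta stacks the q, k, and v rows. The hunk is truncated after the to_q slice; a toy sketch of the full slicing, assuming the conventional q, k, v row order:

    import torch

    dim = 8
    v = torch.zeros(dim, dim)            # base to_q / to_k / to_v weight
    delta = torch.randn(3 * dim, dim)    # delta from a fused qkv LoRA

    chunk = v.shape[0]
    delta_q = delta[0:chunk, ...]          # "to_q" slice (as in the hunk above)
    delta_k = delta[chunk:2 * chunk, ...]  # "to_k" slice (assumed)
    delta_v = delta[2 * chunk:, ...]       # "to_v" slice (assumed)
    assert delta_q.shape == v.shape
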
@@ -384,8 +411,7 @@ def task_merge(...)
                         else: valid_delta = False
                     elif delta.numel() == v.numel():
                         delta = delta.reshape(v.shape)
-                    else:
-                        valid_delta = False
+                    else: valid_delta = False
 
                     if valid_delta:
                         v = v.to(dtype)
@@ -397,23 +423,19 @@ def task_merge(...)
                 buffer.add_tensor(k, v)
                 del v
 
-        os.remove(local_shard)
+        os.remove(shard_file)
         gc.collect()
 
     buffer.flush()
 
-    # Upload Index (Using the dynamically determined index filename)
-    print(f"Uploading Index: {index_filename}")
-
-    index_data = {
-        "metadata": {"total_size": buffer.total_model_size},
-        "weight_map": buffer.index_map
-    }
+    # 6. Upload Index (Now using correct total_size)
+    print(f"Uploading Index: {index_filename} (Total Size: {buffer.total_size})")
+    index_data = {"metadata": {"total_size": buffer.total_size}, "weight_map": buffer.index_map}
 
     with open(TempDir / index_filename, "w") as f:
         json.dump(index_data, f, indent=4)
 
-    path_in_repo = f"{base_subfolder}/{index_filename}" if base_subfolder else index_filename
+    path_in_repo = f"{output_subfolder}/{index_filename}" if output_subfolder else index_filename
     api.upload_file(path_or_fileobj=TempDir / index_filename, path_in_repo=path_in_repo, repo_id=output_repo, token=hf_token)
 
     cleanup_temp()
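
Note: the uploaded index follows the standard sharded-checkpoint layout that transformers/diffusers expect; keys, filenames, and sizes below are purely illustrative:

    index_data = {
        "metadata": {"total_size": 11500000000},  # sum of raw tensor bytes
        "weight_map": {
            "model.layers.0.self_attn.q_proj.weight": "model-00001.safetensors",
            "model.layers.0.self_attn.k_proj.weight": "model-00001.safetensors",
            # ... every tensor key maps to the shard file that holds it
        },
    }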