dagloop5 commited on
Commit
aa36a4f
·
verified ·
1 Parent(s): 0c3d5b1

Update app(workinganddefaultcheckpoint).py

Browse files
Files changed (1) hide show
  1. app(workinganddefaultcheckpoint).py +69 -0
app(workinganddefaultcheckpoint).py CHANGED
@@ -267,6 +267,8 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
267
  # Model repos
268
  LTX_MODEL_REPO = "Lightricks/LTX-2.3"
269
  GEMMA_REPO ="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
 
 
270
 
271
  # Download model checkpoints
272
  print("=" * 80)
@@ -290,6 +292,73 @@ checkpoint_path = hf_hub_download(
290
  local_dir=str(weights_dir),
291
  local_dir_use_symlinks=False,
292
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
294
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
295
 
 
267
  # Model repos
268
  LTX_MODEL_REPO = "Lightricks/LTX-2.3"
269
  GEMMA_REPO ="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
270
+ GEMMA_ABLITERATED_REPO = "Sikaworld1990/gemma-3-12b-it-abliterated-sikaworld-high-fidelity-edition-Ltx-2"
271
+ GEMMA_ABLITERATED_FILE = "gemma-3-12b-it-abliterated-sikaworld-high-fidelity-edition.safetensors"
272
 
273
  # Download model checkpoints
274
  print("=" * 80)
 
292
  local_dir=str(weights_dir),
293
  local_dir_use_symlinks=False,
294
  )
295
+
296
+ print("[Gemma] Setting up abliterated Gemma text encoder...")
297
+ MERGED_WEIGHTS = "/tmp/abliterated_gemma_merged.safetensors"
298
+ gemma_root = "/tmp/abliterated_gemma"
299
+ os.makedirs(gemma_root, exist_ok=True)
300
+
301
+ gemma_official_dir = snapshot_download(
302
+ repo_id=GEMMA_REPO,
303
+ ignore_patterns=["*.safetensors", "*.safetensors.index.json"],
304
+ )
305
+
306
+ for fname in os.listdir(gemma_official_dir):
307
+ src = os.path.join(gemma_official_dir, fname)
308
+ dst = os.path.join(gemma_root, fname)
309
+ if os.path.isfile(src) and not fname.endswith(".safetensors") and fname != "model.safetensors.index.json":
310
+ if not os.path.exists(dst):
311
+ os.symlink(src, dst)
312
+
313
+ if os.path.exists(MERGED_WEIGHTS):
314
+ print("[Gemma] Using cached merged weights")
315
+ else:
316
+ abliterated_weights_path = hf_hub_download(
317
+ repo_id=GEMMA_ABLITERATED_REPO,
318
+ filename=GEMMA_ABLITERATED_FILE,
319
+ )
320
+ index_path = hf_hub_download(
321
+ repo_id=GEMMA_REPO,
322
+ filename="model.safetensors.index.json"
323
+ )
324
+ with open(index_path) as f:
325
+ weight_index = json.load(f)
326
+
327
+ vision_keys = {}
328
+ for key, shard in weight_index["weight_map"].items():
329
+ if "vision_tower" in key or "multi_modal_projector" in key:
330
+ vision_keys[key] = shard
331
+ needed_shards = set(vision_keys.values())
332
+
333
+ shard_paths = {}
334
+ for shard_name in needed_shards:
335
+ shard_paths[shard_name] = hf_hub_download(
336
+ repo_id=GEMMA_REPO,
337
+ filename=shard_name
338
+ )
339
+
340
+ _fp8_types = {torch.float8_e4m3fn, torch.float8_e5m2}
341
+ raw = load_file(abliterated_weights_path)
342
+ merged = {}
343
+ for key, tensor in raw.items():
344
+ t = tensor.to(torch.bfloat16) if tensor.dtype in _fp8_types else tensor
345
+ merged[f"language_model.{key}"] = t
346
+ del raw
347
+
348
+ for key, shard_name in vision_keys.items():
349
+ with safe_open(shard_paths[shard_name], framework="pt") as f:
350
+ merged[key] = f.get_tensor(key)
351
+
352
+ save_file(merged, MERGED_WEIGHTS)
353
+ del merged
354
+ gc.collect()
355
+
356
+ weight_link = os.path.join(gemma_root, "model.safetensors")
357
+ if os.path.exists(weight_link):
358
+ os.remove(weight_link)
359
+ os.symlink(MERGED_WEIGHTS, weight_link)
360
+ print(f"[Gemma] Root ready: {gemma_root}")
361
+
362
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
363
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
364