dagloop5 commited on
Commit
341f0ac
·
verified ·
1 Parent(s): aa36a4f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py CHANGED
@@ -295,6 +295,73 @@ checkpoint_path = hf_hub_download(
295
  local_dir=str(weights_dir),
296
  local_dir_use_symlinks=False,
297
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
299
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
300
 
 
295
  local_dir=str(weights_dir),
296
  local_dir_use_symlinks=False,
297
  )
298
+
299
+ print("[Gemma] Setting up abliterated Gemma text encoder...")
300
+ MERGED_WEIGHTS = "/tmp/abliterated_gemma_merged.safetensors"
301
+ gemma_root = "/tmp/abliterated_gemma"
302
+ os.makedirs(gemma_root, exist_ok=True)
303
+
304
+ gemma_official_dir = snapshot_download(
305
+ repo_id=GEMMA_REPO,
306
+ ignore_patterns=["*.safetensors", "*.safetensors.index.json"],
307
+ )
308
+
309
+ for fname in os.listdir(gemma_official_dir):
310
+ src = os.path.join(gemma_official_dir, fname)
311
+ dst = os.path.join(gemma_root, fname)
312
+ if os.path.isfile(src) and not fname.endswith(".safetensors") and fname != "model.safetensors.index.json":
313
+ if not os.path.exists(dst):
314
+ os.symlink(src, dst)
315
+
316
+ if os.path.exists(MERGED_WEIGHTS):
317
+ print("[Gemma] Using cached merged weights")
318
+ else:
319
+ abliterated_weights_path = hf_hub_download(
320
+ repo_id=GEMMA_ABLITERATED_REPO,
321
+ filename=GEMMA_ABLITERATED_FILE,
322
+ )
323
+ index_path = hf_hub_download(
324
+ repo_id=GEMMA_REPO,
325
+ filename="model.safetensors.index.json"
326
+ )
327
+ with open(index_path) as f:
328
+ weight_index = json.load(f)
329
+
330
+ vision_keys = {}
331
+ for key, shard in weight_index["weight_map"].items():
332
+ if "vision_tower" in key or "multi_modal_projector" in key:
333
+ vision_keys[key] = shard
334
+ needed_shards = set(vision_keys.values())
335
+
336
+ shard_paths = {}
337
+ for shard_name in needed_shards:
338
+ shard_paths[shard_name] = hf_hub_download(
339
+ repo_id=GEMMA_REPO,
340
+ filename=shard_name
341
+ )
342
+
343
+ _fp8_types = {torch.float8_e4m3fn, torch.float8_e5m2}
344
+ raw = load_file(abliterated_weights_path)
345
+ merged = {}
346
+ for key, tensor in raw.items():
347
+ t = tensor.to(torch.bfloat16) if tensor.dtype in _fp8_types else tensor
348
+ merged[f"language_model.{key}"] = t
349
+ del raw
350
+
351
+ for key, shard_name in vision_keys.items():
352
+ with safe_open(shard_paths[shard_name], framework="pt") as f:
353
+ merged[key] = f.get_tensor(key)
354
+
355
+ save_file(merged, MERGED_WEIGHTS)
356
+ del merged
357
+ gc.collect()
358
+
359
+ weight_link = os.path.join(gemma_root, "model.safetensors")
360
+ if os.path.exists(weight_link):
361
+ os.remove(weight_link)
362
+ os.symlink(MERGED_WEIGHTS, weight_link)
363
+ print(f"[Gemma] Root ready: {gemma_root}")
364
+
365
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
366
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
367