LWZ19 committed
Commit 89dfd45 · Parent: 183f9cb

Update app.py

Files changed (1):
  1. app.py (+5 -6)
app.py CHANGED
@@ -40,7 +40,6 @@ sys.modules['sdib.utils.utils'] = mock
 
 # Configuration
 PRUNING_RATIOS = [10, 15, 20]
-LORA_CHECKPOINT_STEP = os.getenv("LORA_CHECKPOINT_STEP")
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 MAX_SEED = np.iinfo(np.int32).max
@@ -62,8 +61,8 @@ for ratio in PRUNING_RATIOS:
     try:
         print(f"Loading {ratio}% pruned model...")
         model_file = hf_hub_download(
-            repo_id="LWZ19/flux_dev_prune",
-            filename=f"pruned_model_{ratio}.pkl"
+            repo_id="LWZ19/flux_prune",
+            filename=f"dev/pruned_model_{ratio}.pkl"
         )
 
         with open(model_file, "rb") as f:
@@ -80,10 +79,10 @@ for ratio in PRUNING_RATIOS:
 print("📥 Preloading LoRA checkpoint for 20% pruning ratio...")
 try:
     lora_repo_path = snapshot_download(
-        repo_id="LWZ19/flux_dev_20_ckp_2",
-        allow_patterns=[f"lora/checkpoint-{LORA_CHECKPOINT_STEP}/*"]
+        repo_id="LWZ19/flux_retrain_weights",
+        allow_patterns=[f"dev/lora/prune_20/*"]
     )
-    lora_weights = load_file(os.path.join(lora_repo_path, "lora", f"checkpoint-{LORA_CHECKPOINT_STEP}", LORA_WEIGHT_NAME_SAFE))
+    lora_weights = load_file(os.path.join(lora_repo_path, "dev", "lora", "prune_20", LORA_WEIGHT_NAME_SAFE))
     print("✅ LoRA checkpoint loaded!")
 except Exception as e:
     print(f"❌ Failed to load LoRA checkpoint: {e}")
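The net effect of the change: the LORA_CHECKPOINT_STEP environment variable is dropped, and both downloads are repointed at consolidated repos under a "dev/" prefix. For reference, a minimal sketch of the new pruned-model loading path, assuming the .pkl checkpoints are plain pickle files (the surrounding `with open(model_file, "rb")` block suggests as much); the `pruned_models` dict is a hypothetical accumulator, not taken from app.py.

import pickle

from huggingface_hub import hf_hub_download

PRUNING_RATIOS = [10, 15, 20]
pruned_models = {}  # hypothetical accumulator for the loaded checkpoints

for ratio in PRUNING_RATIOS:
    # Checkpoints now live in the consolidated LWZ19/flux_prune repo,
    # namespaced under a "dev/" prefix instead of a per-model repo.
    model_file = hf_hub_download(
        repo_id="LWZ19/flux_prune",
        filename=f"dev/pruned_model_{ratio}.pkl",
    )
    with open(model_file, "rb") as f:
        pruned_models[ratio] = pickle.load(f)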
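And a sketch of the reworked LoRA preload. LORA_WEIGHT_NAME_SAFE is defined elsewhere in app.py; the value below is an assumed placeholder (the diffusers default LoRA filename), not taken from the commit.

import os

from huggingface_hub import snapshot_download
from safetensors.torch import load_file

LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"  # assumed filename

# The fixed "dev/lora/prune_20" path replaces the old
# checkpoint-{LORA_CHECKPOINT_STEP} layout, so the env var is no longer needed.
lora_repo_path = snapshot_download(
    repo_id="LWZ19/flux_retrain_weights",
    allow_patterns=["dev/lora/prune_20/*"],  # only pull the prune_20 folder
)
lora_weights = load_file(
    os.path.join(lora_repo_path, "dev", "lora", "prune_20", LORA_WEIGHT_NAME_SAFE)
)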