Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -40,7 +40,6 @@ sys.modules['sdib.utils.utils'] = mock
|
|
| 40 |
|
| 41 |
# Configuration
|
| 42 |
PRUNING_RATIOS = [10, 15, 20]
|
| 43 |
-
LORA_CHECKPOINT_STEP = os.getenv("LORA_CHECKPOINT_STEP")
|
| 44 |
|
| 45 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 46 |
MAX_SEED = np.iinfo(np.int32).max
|
|
@@ -62,8 +61,8 @@ for ratio in PRUNING_RATIOS:
|
|
| 62 |
try:
|
| 63 |
print(f"Loading {ratio}% pruned model...")
|
| 64 |
model_file = hf_hub_download(
|
| 65 |
-
repo_id="LWZ19/
|
| 66 |
-
filename=f"pruned_model_{ratio}.pkl"
|
| 67 |
)
|
| 68 |
|
| 69 |
with open(model_file, "rb") as f:
|
|
@@ -80,10 +79,10 @@ for ratio in PRUNING_RATIOS:
|
|
| 80 |
print("📥 Preloading LoRA checkpoint for 20% pruning ratio...")
|
| 81 |
try:
|
| 82 |
lora_repo_path = snapshot_download(
|
| 83 |
-
repo_id="LWZ19/
|
| 84 |
-
allow_patterns=[f"lora/
|
| 85 |
)
|
| 86 |
-
lora_weights = load_file(os.path.join(lora_repo_path, "lora",
|
| 87 |
print("✅ LoRA checkpoint loaded!")
|
| 88 |
except Exception as e:
|
| 89 |
print(f"❌ Failed to load LoRA checkpoint: {e}")
|
|
|
|
| 40 |
|
| 41 |
# Configuration
|
| 42 |
PRUNING_RATIOS = [10, 15, 20]
|
|
|
|
| 43 |
|
| 44 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 45 |
MAX_SEED = np.iinfo(np.int32).max
|
|
|
|
| 61 |
try:
|
| 62 |
print(f"Loading {ratio}% pruned model...")
|
| 63 |
model_file = hf_hub_download(
|
| 64 |
+
repo_id="LWZ19/flux_prune",
|
| 65 |
+
filename=f"dev/pruned_model_{ratio}.pkl"
|
| 66 |
)
|
| 67 |
|
| 68 |
with open(model_file, "rb") as f:
|
|
|
|
| 79 |
print("📥 Preloading LoRA checkpoint for 20% pruning ratio...")
|
| 80 |
try:
|
| 81 |
lora_repo_path = snapshot_download(
|
| 82 |
+
repo_id="LWZ19/flux_retrain_weights",
|
| 83 |
+
allow_patterns=[f"dev/lora/prune_20/*"]
|
| 84 |
)
|
| 85 |
+
lora_weights = load_file(os.path.join(lora_repo_path, "dev", "lora", "prune_20", LORA_WEIGHT_NAME_SAFE))
|
| 86 |
print("✅ LoRA checkpoint loaded!")
|
| 87 |
except Exception as e:
|
| 88 |
print(f"❌ Failed to load LoRA checkpoint: {e}")
|