load lora weights
Browse files
app.py
CHANGED
|
@@ -10,7 +10,7 @@ import torch
|
|
| 10 |
import torch.nn.functional as F
|
| 11 |
import gradio as gr
|
| 12 |
from PIL import Image, ImageFilter, ImageChops
|
| 13 |
-
from huggingface_hub import hf_hub_download
|
| 14 |
import spaces
|
| 15 |
|
| 16 |
# --- IMPORT YOUR CUSTOM MODULES ---
|
|
@@ -28,6 +28,7 @@ SAM2_CONFIG = "sam2_hiera_l.yaml"
|
|
| 28 |
BASE_CKPT_NAME = "sam2_hiera_large.pt"
|
| 29 |
FINE_TUNED_SAM = "fine_tuned_sam2_batched_90000.torch"
|
| 30 |
FINE_TUNED_PLM = "fine_tuned_sam2_batched_plm_90000.torch"
|
|
|
|
| 31 |
|
| 32 |
SQUARE_DIM = 1024
|
| 33 |
|
|
@@ -85,7 +86,7 @@ def ensure_models_loaded():
|
|
| 85 |
sd = torch.load(sam_ckpt_path, map_location="cpu")
|
| 86 |
model.load_state_dict(sd.get("model", sd), strict=True)
|
| 87 |
|
| 88 |
-
# 2. Load PLM Adapter
|
| 89 |
plm_path = download_if_needed(REPO_ID, FINE_TUNED_PLM)
|
| 90 |
plm = PLMLanguageAdapter(
|
| 91 |
model_name="Qwen/Qwen2.5-VL-3B-Instruct",
|
|
@@ -96,6 +97,21 @@ def ensure_models_loaded():
|
|
| 96 |
)
|
| 97 |
plm_sd = torch.load(plm_path, map_location="cpu")
|
| 98 |
plm.load_state_dict(plm_sd["plm"], strict=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
plm.eval()
|
| 100 |
|
| 101 |
MODEL_CACHE["sam"] = model
|
|
|
|
| 10 |
import torch.nn.functional as F
|
| 11 |
import gradio as gr
|
| 12 |
from PIL import Image, ImageFilter, ImageChops
|
| 13 |
+
from huggingface_hub import hf_hub_download, snapshot_download # <--- Added snapshot_download
|
| 14 |
import spaces
|
| 15 |
|
| 16 |
# --- IMPORT YOUR CUSTOM MODULES ---
|
|
|
|
| 28 |
BASE_CKPT_NAME = "sam2_hiera_large.pt"
|
| 29 |
FINE_TUNED_SAM = "fine_tuned_sam2_batched_90000.torch"
|
| 30 |
FINE_TUNED_PLM = "fine_tuned_sam2_batched_plm_90000.torch"
|
| 31 |
+
FINE_TUNED_LORA = "lora_plm_adapter_90000" # Folder name in the HF Repo
|
| 32 |
|
| 33 |
SQUARE_DIM = 1024
|
| 34 |
|
|
|
|
| 86 |
sd = torch.load(sam_ckpt_path, map_location="cpu")
|
| 87 |
model.load_state_dict(sd.get("model", sd), strict=True)
|
| 88 |
|
| 89 |
+
# 2. Load PLM Adapter Base
|
| 90 |
plm_path = download_if_needed(REPO_ID, FINE_TUNED_PLM)
|
| 91 |
plm = PLMLanguageAdapter(
|
| 92 |
model_name="Qwen/Qwen2.5-VL-3B-Instruct",
|
|
|
|
| 97 |
)
|
| 98 |
plm_sd = torch.load(plm_path, map_location="cpu")
|
| 99 |
plm.load_state_dict(plm_sd["plm"], strict=True)
|
| 100 |
+
|
| 101 |
+
# 3. Load LoRA Weights
|
| 102 |
+
try:
|
| 103 |
+
logging.info(f"Downloading LoRA folder: {FINE_TUNED_LORA}...")
|
| 104 |
+
# snapshot_download returns the root cache folder; we use allow_patterns to get just the LoRA folder
|
| 105 |
+
cache_root = snapshot_download(repo_id=REPO_ID, allow_patterns=f"{FINE_TUNED_LORA}/*")
|
| 106 |
+
|
| 107 |
+
# Construct the full path to the directory containing the LoRA files
|
| 108 |
+
lora_dir_path = os.path.join(cache_root, FINE_TUNED_LORA)
|
| 109 |
+
|
| 110 |
+
logging.info(f"Loading LoRA from {lora_dir_path}...")
|
| 111 |
+
plm.load_lora(lora_dir_path)
|
| 112 |
+
except Exception as e:
|
| 113 |
+
raise RuntimeError(f"Failed to load LoRA weights: {e}")
|
| 114 |
+
|
| 115 |
plm.eval()
|
| 116 |
|
| 117 |
MODEL_CACHE["sam"] = model
|