aadarsh99 committed on
Commit
0a57513
·
1 Parent(s): 37ff04d

load lora weights

Browse files
Files changed (1) hide show
  1. app.py +18 -2
app.py CHANGED
@@ -10,7 +10,7 @@ import torch
10
  import torch.nn.functional as F
11
  import gradio as gr
12
  from PIL import Image, ImageFilter, ImageChops
13
- from huggingface_hub import hf_hub_download
14
  import spaces
15
 
16
  # --- IMPORT YOUR CUSTOM MODULES ---
@@ -28,6 +28,7 @@ SAM2_CONFIG = "sam2_hiera_l.yaml"
28
  BASE_CKPT_NAME = "sam2_hiera_large.pt"
29
  FINE_TUNED_SAM = "fine_tuned_sam2_batched_90000.torch"
30
  FINE_TUNED_PLM = "fine_tuned_sam2_batched_plm_90000.torch"
 
31
 
32
  SQUARE_DIM = 1024
33
 
@@ -85,7 +86,7 @@ def ensure_models_loaded():
85
  sd = torch.load(sam_ckpt_path, map_location="cpu")
86
  model.load_state_dict(sd.get("model", sd), strict=True)
87
 
88
- # 2. Load PLM Adapter
89
  plm_path = download_if_needed(REPO_ID, FINE_TUNED_PLM)
90
  plm = PLMLanguageAdapter(
91
  model_name="Qwen/Qwen2.5-VL-3B-Instruct",
@@ -96,6 +97,21 @@ def ensure_models_loaded():
96
  )
97
  plm_sd = torch.load(plm_path, map_location="cpu")
98
  plm.load_state_dict(plm_sd["plm"], strict=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  plm.eval()
100
 
101
  MODEL_CACHE["sam"] = model
 
10
  import torch.nn.functional as F
11
  import gradio as gr
12
  from PIL import Image, ImageFilter, ImageChops
13
+ from huggingface_hub import hf_hub_download, snapshot_download # <--- Added snapshot_download
14
  import spaces
15
 
16
  # --- IMPORT YOUR CUSTOM MODULES ---
 
28
  BASE_CKPT_NAME = "sam2_hiera_large.pt"
29
  FINE_TUNED_SAM = "fine_tuned_sam2_batched_90000.torch"
30
  FINE_TUNED_PLM = "fine_tuned_sam2_batched_plm_90000.torch"
31
+ FINE_TUNED_LORA = "lora_plm_adapter_90000" # Folder name in the HF Repo
32
 
33
  SQUARE_DIM = 1024
34
 
 
86
  sd = torch.load(sam_ckpt_path, map_location="cpu")
87
  model.load_state_dict(sd.get("model", sd), strict=True)
88
 
89
+ # 2. Load PLM Adapter Base
90
  plm_path = download_if_needed(REPO_ID, FINE_TUNED_PLM)
91
  plm = PLMLanguageAdapter(
92
  model_name="Qwen/Qwen2.5-VL-3B-Instruct",
 
97
  )
98
  plm_sd = torch.load(plm_path, map_location="cpu")
99
  plm.load_state_dict(plm_sd["plm"], strict=True)
100
+
101
+ # 3. Load LoRA Weights
102
+ try:
103
+ logging.info(f"Downloading LoRA folder: {FINE_TUNED_LORA}...")
104
+ # snapshot_download returns the root cache folder; we use allow_patterns to get just the LoRA folder
105
+ cache_root = snapshot_download(repo_id=REPO_ID, allow_patterns=f"{FINE_TUNED_LORA}/*")
106
+
107
+ # Construct the full path to the directory containing the LoRA files
108
+ lora_dir_path = os.path.join(cache_root, FINE_TUNED_LORA)
109
+
110
+ logging.info(f"Loading LoRA from {lora_dir_path}...")
111
+ plm.load_lora(lora_dir_path)
112
+ except Exception as e:
113
+ raise RuntimeError(f"Failed to load LoRA weights: {e}")
114
+
115
  plm.eval()
116
 
117
  MODEL_CACHE["sam"] = model