aadarsh99 committed on
Commit
39ba93b
·
1 Parent(s): b6001f7

update app

Browse files
Files changed (1) hide show
  1. app.py +36 -31
app.py CHANGED
@@ -3,6 +3,7 @@ import logging
3
  import hashlib
4
  import sys
5
  import traceback
 
6
 
7
  import cv2
8
  import numpy as np
@@ -10,36 +11,50 @@ import torch
10
  import torch.nn.functional as F
11
  import gradio as gr
12
  from PIL import Image, ImageFilter, ImageChops, ImageDraw
13
- from huggingface_hub import hf_hub_download
14
- import spaces # <--- NEW IMPORT
15
 
16
  # --- IMPORT YOUR CUSTOM MODULES ---
17
- # Ensure the 'sam2' folder and 'plm_adapter_...' file are uploaded to your Space
18
  from sam2.build_sam import build_sam2
19
  from sam2.sam2_image_predictor import SAM2ImagePredictor
20
  from sam2.modeling.sam.mask_decoder import MaskDecoder
21
  from plm_adapter_lora_with_image_input_only_text_positions import PLMLanguageAdapter
22
 
23
  # ----------------- Configuration -----------------
24
- # UPDATE THESE TO MATCH YOUR HF REPO IF YOU STORE WEIGHTS THERE
25
  HF_REPO_ID = "aadarsh99/ConvSeg-Stage1"
26
  SAM2_CONFIG = "sam2_hiera_l.yaml"
27
 
28
- # Checkpoint filenames (assumed to be in the root or downloaded)
29
  BASE_CKPT_NAME = "sam2_hiera_large.pt"
30
- FINAL_CKPT_NAME = "fine_tuned_sam2_batched_100000.torch" # Update with your filename
31
- PLM_CKPT_NAME = "fine_tuned_sam2_batched_plm_100000.torch" # Update with your filename
32
- LORA_CKPT_NAME = "lora_plm_adapter_100000" # Set filename if you use LoRA, else None
33
 
34
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
35
  SQUARE_DIM = 1024
36
-
37
  logging.basicConfig(level=logging.INFO)
38
 
39
  # ----------------- Globals (Lazy Loading) -----------------
40
  MODEL_SAM = None
41
  PLM = None
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  # ----------------- Overlay Style Helpers -----------------
44
  EDGE_COLORS_HEX = ["#3A86FF", "#FF006E", "#43AA8B", "#F3722C", "#8338EC", "#90BE6D"]
45
 
@@ -214,21 +229,14 @@ def load_models_lazy():
214
  print("Lazy loading models inside GPU context...")
215
 
216
  # 1. Base SAM2 Model
217
- if not os.path.exists(BASE_CKPT_NAME):
218
- raise FileNotFoundError(f"{BASE_CKPT_NAME} not found")
219
-
220
- # On ZeroGPU, we can load to 'cuda' directly, or 'cpu' then move.
221
- # To be safe against the deepcopy error, we load to cpu then move.
222
- # If the deepcopy error persists, we might need to load directly to 'cuda'.
223
- # Let's try CPU load -> move to cuda.
224
 
225
- model = build_sam2(SAM2_CONFIG, BASE_CKPT_NAME, device="cpu")
 
226
 
227
  # 2. Fine-tuned Weights
228
- if not os.path.exists(FINAL_CKPT_NAME):
229
- raise FileNotFoundError(f"{FINAL_CKPT_NAME} not found")
230
-
231
- sd = torch.load(FINAL_CKPT_NAME, map_location="cpu")
232
  model.load_state_dict(sd.get("model", sd), strict=True)
233
 
234
  # Move SAM to CUDA now
@@ -250,14 +258,13 @@ def load_models_lazy():
250
  device="cpu", # Init on CPU
251
  )
252
 
253
- if not os.path.exists(PLM_CKPT_NAME):
254
- raise FileNotFoundError(f"{PLM_CKPT_NAME} not found")
255
-
256
- plm_sd = torch.load(PLM_CKPT_NAME, map_location="cpu")
257
  plm.load_state_dict(plm_sd["plm"], strict=True)
258
 
259
- if LORA_CKPT_NAME and os.path.exists(LORA_CKPT_NAME):
260
- plm.load_lora(LORA_CKPT_NAME)
 
261
 
262
  # Move PLM to CUDA
263
  plm.to("cuda")
@@ -268,7 +275,7 @@ def load_models_lazy():
268
  return MODEL_SAM, PLM
269
 
270
 
271
- @spaces.GPU(duration=120) # Increased duration for first-time load
272
  def run_prediction(image_pil, text_prompt):
273
  if image_pil is None or not text_prompt:
274
  return None, None, None
@@ -280,8 +287,6 @@ def run_prediction(image_pil, text_prompt):
280
  model_sam, plm = load_models_lazy()
281
 
282
  # 2. Instantiate Predictor
283
- # We assume models are already on CUDA from load_models_lazy
284
- # Just to be sure, we can call .to("cuda") again (cheap if already there)
285
  model_sam.to("cuda")
286
  plm.to("cuda")
287
 
 
3
  import hashlib
4
  import sys
5
  import traceback
6
+ import copy
7
 
8
  import cv2
9
  import numpy as np
 
11
  import torch.nn.functional as F
12
  import gradio as gr
13
  from PIL import Image, ImageFilter, ImageChops, ImageDraw
14
+ from huggingface_hub import hf_hub_download # <--- NEW IMPORT
15
+ import spaces
16
 
17
  # --- IMPORT YOUR CUSTOM MODULES ---
 
18
  from sam2.build_sam import build_sam2
19
  from sam2.sam2_image_predictor import SAM2ImagePredictor
20
  from sam2.modeling.sam.mask_decoder import MaskDecoder
21
  from plm_adapter_lora_with_image_input_only_text_positions import PLMLanguageAdapter
22
 
23
  # ----------------- Configuration -----------------
 
24
  HF_REPO_ID = "aadarsh99/ConvSeg-Stage1"
25
  SAM2_CONFIG = "sam2_hiera_l.yaml"
26
 
27
+ # Filenames
28
  BASE_CKPT_NAME = "sam2_hiera_large.pt"
29
+ FINAL_CKPT_NAME = "fine_tuned_sam2_batched_100000.torch"
30
+ PLM_CKPT_NAME = "fine_tuned_sam2_batched_plm_100000.torch"
31
+ LORA_CKPT_NAME = None
32
 
 
33
  SQUARE_DIM = 1024
 
34
  logging.basicConfig(level=logging.INFO)
35
 
36
  # ----------------- Globals (Lazy Loading) -----------------
37
  MODEL_SAM = None
38
  PLM = None
39
 
40
+ # ----------------- Helper: Download Logic -----------------
41
def download_if_needed(filename):
    """Return a local path for *filename*, downloading it from the HF repo if needed.

    Checks the current working directory first; if the file is absent, it is
    fetched from ``HF_REPO_ID`` via ``hf_hub_download`` (which caches the
    download and returns the cached path).

    Args:
        filename: Checkpoint filename expected either locally or in the HF repo.

    Returns:
        str: Path to the existing local file, or the hub-cache path of the download.

    Raises:
        FileNotFoundError: If the file exists neither locally nor in the HF repo.
    """
    if os.path.exists(filename):
        logging.info(f"Found local file: {filename}")
        return filename

    logging.info(f"{filename} not found locally. Downloading from {HF_REPO_ID}...")
    try:
        path = hf_hub_download(repo_id=HF_REPO_ID, filename=filename)
        logging.info(f"Downloaded to: {path}")
        return path
    except Exception as e:
        # Chain the underlying hub error so download failures stay debuggable.
        raise FileNotFoundError(
            f"Could not find {filename} locally or in HF repo {HF_REPO_ID}. Error: {e}"
        ) from e
57
+
58
  # ----------------- Overlay Style Helpers -----------------
59
  EDGE_COLORS_HEX = ["#3A86FF", "#FF006E", "#43AA8B", "#F3722C", "#8338EC", "#90BE6D"]
60
 
 
229
  print("Lazy loading models inside GPU context...")
230
 
231
  # 1. Base SAM2 Model
232
+ base_path = download_if_needed(BASE_CKPT_NAME)
 
 
 
 
 
 
233
 
234
+ # Init on CPU to avoid "deepcopy" errors, then move later
235
+ model = build_sam2(SAM2_CONFIG, base_path, device="cpu")
236
 
237
  # 2. Fine-tuned Weights
238
+ final_path = download_if_needed(FINAL_CKPT_NAME)
239
+ sd = torch.load(final_path, map_location="cpu")
 
 
240
  model.load_state_dict(sd.get("model", sd), strict=True)
241
 
242
  # Move SAM to CUDA now
 
258
  device="cpu", # Init on CPU
259
  )
260
 
261
+ plm_path = download_if_needed(PLM_CKPT_NAME)
262
+ plm_sd = torch.load(plm_path, map_location="cpu")
 
 
263
  plm.load_state_dict(plm_sd["plm"], strict=True)
264
 
265
+ if LORA_CKPT_NAME:
266
+ lora_path = download_if_needed(LORA_CKPT_NAME)
267
+ plm.load_lora(lora_path)
268
 
269
  # Move PLM to CUDA
270
  plm.to("cuda")
 
275
  return MODEL_SAM, PLM
276
 
277
 
278
+ @spaces.GPU(duration=180) # Increased duration for download + load
279
  def run_prediction(image_pil, text_prompt):
280
  if image_pil is None or not text_prompt:
281
  return None, None, None
 
287
  model_sam, plm = load_models_lazy()
288
 
289
  # 2. Instantiate Predictor
 
 
290
  model_sam.to("cuda")
291
  plm.to("cuda")
292