aadarsh99 commited on
Commit
9d1694a
·
1 Parent(s): 8f65995

update app

Browse files
Files changed (1) hide show
  1. app.py +42 -34
app.py CHANGED
@@ -32,9 +32,11 @@ LORA_CKPT_NAME = None
32
  SQUARE_DIM = 1024
33
  logging.basicConfig(level=logging.INFO)
34
 
35
- # ----------------- Globals (Lazy Loading) -----------------
36
- MODEL_SAM = None
37
- PLM = None
 
 
38
 
39
  # ----------------- Helper: Download Logic -----------------
40
  def download_if_needed(filename):
@@ -46,10 +48,11 @@ def download_if_needed(filename):
46
  logging.info(f"Found local file: {filename}")
47
  return filename
48
 
49
- logging.info(f"{filename} not found locally. Downloading from {HF_REPO_ID}...")
 
 
50
  try:
51
  path = hf_hub_download(repo_id=HF_REPO_ID, filename=filename)
52
- logging.info(f"Downloaded to: {path}")
53
  return path
54
  except Exception as e:
55
  raise FileNotFoundError(f"Could not find {filename} locally or in HF repo {HF_REPO_ID}. Error: {e}")
@@ -170,23 +173,24 @@ def _unpad_and_resize_pred_to_gt(logit_sq: torch.Tensor, meta: dict, out_hw: tup
170
  up = F.interpolate(crop, size=out_hw, mode="bilinear", align_corners=False)
171
  return up[0, 0]
172
 
173
- # ----------------- Model Loading -----------------
174
 
175
- def load_models_lazy():
176
  """
177
- Loads the models. This must be called INSIDE the @spaces.GPU context.
 
178
  """
179
- global MODEL_SAM, PLM
180
 
181
- if MODEL_SAM is not None and PLM is not None:
182
- return MODEL_SAM, PLM
183
 
184
- print("Lazy loading models inside GPU context...")
185
 
186
  # 1. Base SAM2 Model
187
  base_path = download_if_needed(BASE_CKPT_NAME)
188
 
189
- # Init on CPU
190
  model = build_sam2(SAM2_CONFIG, base_path, device="cpu")
191
 
192
  # 2. Fine-tuned Weights
@@ -194,9 +198,8 @@ def load_models_lazy():
194
  sd = torch.load(final_path, map_location="cpu")
195
  model.load_state_dict(sd.get("model", sd), strict=True)
196
 
197
- # Move SAM to CUDA now
198
- model.to("cuda")
199
- MODEL_SAM = model
200
 
201
  # 3. PLM Adapter
202
  C = model.sam_mask_decoder.transformer_dim
@@ -210,7 +213,7 @@ def load_models_lazy():
210
  lora_alpha=32,
211
  lora_dropout=0.05,
212
  dtype=torch.bfloat16,
213
- device="cpu", # Init on CPU
214
  )
215
 
216
  plm_path = download_if_needed(PLM_CKPT_NAME)
@@ -221,31 +224,30 @@ def load_models_lazy():
221
  lora_path = download_if_needed(LORA_CKPT_NAME)
222
  plm.load_lora(lora_path)
223
 
224
- # Move PLM to CUDA
225
- plm.to("cuda")
226
  plm.eval()
227
- PLM = plm
 
228
 
229
- print("Models loaded successfully.")
230
- return MODEL_SAM, PLM
231
 
232
-
233
- @spaces.GPU(duration=180)
234
  def run_prediction(image_pil, text_prompt):
235
  if image_pil is None or not text_prompt:
236
  return None, None
237
 
 
 
 
 
 
 
 
 
 
238
  predictor = None
239
 
240
  try:
241
- # 1. Ensure models are loaded (Lazy Load)
242
- model_sam, plm = load_models_lazy()
243
-
244
- # 2. Instantiate Predictor
245
- model_sam.to("cuda")
246
- plm.to("cuda")
247
-
248
- predictor = SAM2ImagePredictor(model_sam)
249
 
250
  # 3. Preprocess Image
251
  rgb_orig = np.array(image_pil.convert("RGB"))
@@ -263,7 +265,7 @@ def run_prediction(image_pil, text_prompt):
263
  temp_path = "temp_input.jpg"
264
  image_pil.save(temp_path)
265
 
266
- sp, dp = plm([text_prompt], H_feat, W_feat, [temp_path])
267
 
268
  # 6. Prepare SAM2 Decoder inputs
269
  dec = predictor.model.sam_mask_decoder
@@ -306,7 +308,13 @@ def run_prediction(image_pil, text_prompt):
306
  raise e
307
 
308
  finally:
309
- # Cleanup
 
 
 
 
 
 
310
  if predictor:
311
  del predictor
312
  torch.cuda.empty_cache()
 
32
  SQUARE_DIM = 1024
33
  logging.basicConfig(level=logging.INFO)
34
 
35
+ # ----------------- Globals (Ram Cache) -----------------
36
+ # We keep these on CPU globally so they persist between runs
37
+ # without taking up GPU memory (which gets reset).
38
+ MODEL_SAM_CPU = None
39
+ PLM_CPU = None
40
 
41
  # ----------------- Helper: Download Logic -----------------
42
  def download_if_needed(filename):
 
48
  logging.info(f"Found local file: {filename}")
49
  return filename
50
 
51
+ # hf_hub_download checks the cache automatically.
52
+ # It won't re-download if the file is already in the HF cache.
53
+ logging.info(f"Checking HF Cache for {filename}...")
54
  try:
55
  path = hf_hub_download(repo_id=HF_REPO_ID, filename=filename)
 
56
  return path
57
  except Exception as e:
58
  raise FileNotFoundError(f"Could not find {filename} locally or in HF repo {HF_REPO_ID}. Error: {e}")
 
173
  up = F.interpolate(crop, size=out_hw, mode="bilinear", align_corners=False)
174
  return up[0, 0]
175
 
176
+ # ----------------- Model Loading (CPU Caching) -----------------
177
 
178
+ def ensure_models_loaded_on_cpu():
179
  """
180
+ Ensures models are loaded in Global CPU RAM.
181
+ This avoids re-reading from disk/cache on every run.
182
  """
183
+ global MODEL_SAM_CPU, PLM_CPU
184
 
185
+ if MODEL_SAM_CPU is not None and PLM_CPU is not None:
186
+ return # Already loaded in RAM
187
 
188
+ logging.info("Loading models into CPU RAM (this happens once)...")
189
 
190
  # 1. Base SAM2 Model
191
  base_path = download_if_needed(BASE_CKPT_NAME)
192
 
193
+ # Build on CPU
194
  model = build_sam2(SAM2_CONFIG, base_path, device="cpu")
195
 
196
  # 2. Fine-tuned Weights
 
198
  sd = torch.load(final_path, map_location="cpu")
199
  model.load_state_dict(sd.get("model", sd), strict=True)
200
 
201
+ # Save to Global (CPU)
202
+ MODEL_SAM_CPU = model
 
203
 
204
  # 3. PLM Adapter
205
  C = model.sam_mask_decoder.transformer_dim
 
213
  lora_alpha=32,
214
  lora_dropout=0.05,
215
  dtype=torch.bfloat16,
216
+ device="cpu",
217
  )
218
 
219
  plm_path = download_if_needed(PLM_CKPT_NAME)
 
224
  lora_path = download_if_needed(LORA_CKPT_NAME)
225
  plm.load_lora(lora_path)
226
 
 
 
227
  plm.eval()
228
+ PLM_CPU = plm
229
+ logging.info("Models successfully loaded into CPU RAM.")
230
 
 
 
231
 
232
+ @spaces.GPU(duration=120)
 
233
  def run_prediction(image_pil, text_prompt):
234
  if image_pil is None or not text_prompt:
235
  return None, None
236
 
237
+ # 1. Ensure models are in RAM (Fast check)
238
+ ensure_models_loaded_on_cpu()
239
+
240
+ # 2. Move to GPU (The only 'loading' cost per run)
241
+ # We rely on the global variables
242
+ logging.info("Moving models to GPU...")
243
+ MODEL_SAM_CPU.to("cuda")
244
+ PLM_CPU.to("cuda")
245
+
246
  predictor = None
247
 
248
  try:
249
+ # Instantiate Predictor on GPU
250
+ predictor = SAM2ImagePredictor(MODEL_SAM_CPU)
 
 
 
 
 
 
251
 
252
  # 3. Preprocess Image
253
  rgb_orig = np.array(image_pil.convert("RGB"))
 
265
  temp_path = "temp_input.jpg"
266
  image_pil.save(temp_path)
267
 
268
+ sp, dp = PLM_CPU([text_prompt], H_feat, W_feat, [temp_path])
269
 
270
  # 6. Prepare SAM2 Decoder inputs
271
  dec = predictor.model.sam_mask_decoder
 
308
  raise e
309
 
310
  finally:
311
+ # CRITICAL: Move models back to CPU
312
+ # This preserves the Global Variable on CPU RAM for the next run.
313
+ # If we leave them on CUDA, they might be lost when ZeroGPU releases the device.
314
+ logging.info("Moving models back to CPU...")
315
+ MODEL_SAM_CPU.to("cpu")
316
+ PLM_CPU.to("cpu")
317
+
318
  if predictor:
319
  del predictor
320
  torch.cuda.empty_cache()