update app
Browse files
app.py
CHANGED
|
@@ -32,9 +32,11 @@ LORA_CKPT_NAME = None
|
|
| 32 |
SQUARE_DIM = 1024
|
| 33 |
logging.basicConfig(level=logging.INFO)
|
| 34 |
|
| 35 |
-
# ----------------- Globals (
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
| 38 |
|
| 39 |
# ----------------- Helper: Download Logic -----------------
|
| 40 |
def download_if_needed(filename):
|
|
@@ -46,10 +48,11 @@ def download_if_needed(filename):
|
|
| 46 |
logging.info(f"Found local file: {filename}")
|
| 47 |
return filename
|
| 48 |
|
| 49 |
-
|
|
|
|
|
|
|
| 50 |
try:
|
| 51 |
path = hf_hub_download(repo_id=HF_REPO_ID, filename=filename)
|
| 52 |
-
logging.info(f"Downloaded to: {path}")
|
| 53 |
return path
|
| 54 |
except Exception as e:
|
| 55 |
raise FileNotFoundError(f"Could not find {filename} locally or in HF repo {HF_REPO_ID}. Error: {e}")
|
|
@@ -170,23 +173,24 @@ def _unpad_and_resize_pred_to_gt(logit_sq: torch.Tensor, meta: dict, out_hw: tup
|
|
| 170 |
up = F.interpolate(crop, size=out_hw, mode="bilinear", align_corners=False)
|
| 171 |
return up[0, 0]
|
| 172 |
|
| 173 |
-
# ----------------- Model Loading -----------------
|
| 174 |
|
| 175 |
-
def
|
| 176 |
"""
|
| 177 |
-
|
|
|
|
| 178 |
"""
|
| 179 |
-
global
|
| 180 |
|
| 181 |
-
if
|
| 182 |
-
return
|
| 183 |
|
| 184 |
-
|
| 185 |
|
| 186 |
# 1. Base SAM2 Model
|
| 187 |
base_path = download_if_needed(BASE_CKPT_NAME)
|
| 188 |
|
| 189 |
-
#
|
| 190 |
model = build_sam2(SAM2_CONFIG, base_path, device="cpu")
|
| 191 |
|
| 192 |
# 2. Fine-tuned Weights
|
|
@@ -194,9 +198,8 @@ def load_models_lazy():
|
|
| 194 |
sd = torch.load(final_path, map_location="cpu")
|
| 195 |
model.load_state_dict(sd.get("model", sd), strict=True)
|
| 196 |
|
| 197 |
-
#
|
| 198 |
-
model
|
| 199 |
-
MODEL_SAM = model
|
| 200 |
|
| 201 |
# 3. PLM Adapter
|
| 202 |
C = model.sam_mask_decoder.transformer_dim
|
|
@@ -210,7 +213,7 @@ def load_models_lazy():
|
|
| 210 |
lora_alpha=32,
|
| 211 |
lora_dropout=0.05,
|
| 212 |
dtype=torch.bfloat16,
|
| 213 |
-
device="cpu",
|
| 214 |
)
|
| 215 |
|
| 216 |
plm_path = download_if_needed(PLM_CKPT_NAME)
|
|
@@ -221,31 +224,30 @@ def load_models_lazy():
|
|
| 221 |
lora_path = download_if_needed(LORA_CKPT_NAME)
|
| 222 |
plm.load_lora(lora_path)
|
| 223 |
|
| 224 |
-
# Move PLM to CUDA
|
| 225 |
-
plm.to("cuda")
|
| 226 |
plm.eval()
|
| 227 |
-
|
|
|
|
| 228 |
|
| 229 |
-
print("Models loaded successfully.")
|
| 230 |
-
return MODEL_SAM, PLM
|
| 231 |
|
| 232 |
-
|
| 233 |
-
@spaces.GPU(duration=180)
|
| 234 |
def run_prediction(image_pil, text_prompt):
|
| 235 |
if image_pil is None or not text_prompt:
|
| 236 |
return None, None
|
| 237 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
predictor = None
|
| 239 |
|
| 240 |
try:
|
| 241 |
-
#
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
# 2. Instantiate Predictor
|
| 245 |
-
model_sam.to("cuda")
|
| 246 |
-
plm.to("cuda")
|
| 247 |
-
|
| 248 |
-
predictor = SAM2ImagePredictor(model_sam)
|
| 249 |
|
| 250 |
# 3. Preprocess Image
|
| 251 |
rgb_orig = np.array(image_pil.convert("RGB"))
|
|
@@ -263,7 +265,7 @@ def run_prediction(image_pil, text_prompt):
|
|
| 263 |
temp_path = "temp_input.jpg"
|
| 264 |
image_pil.save(temp_path)
|
| 265 |
|
| 266 |
-
sp, dp =
|
| 267 |
|
| 268 |
# 6. Prepare SAM2 Decoder inputs
|
| 269 |
dec = predictor.model.sam_mask_decoder
|
|
@@ -306,7 +308,13 @@ def run_prediction(image_pil, text_prompt):
|
|
| 306 |
raise e
|
| 307 |
|
| 308 |
finally:
|
| 309 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
if predictor:
|
| 311 |
del predictor
|
| 312 |
torch.cuda.empty_cache()
|
|
|
|
| 32 |
SQUARE_DIM = 1024
|
| 33 |
logging.basicConfig(level=logging.INFO)
|
| 34 |
|
| 35 |
+
# ----------------- Globals (RAM Cache) -----------------
|
| 36 |
+
# We keep these on CPU globally so they persist between runs
|
| 37 |
+
# without taking up GPU memory (which is reset when the GPU is released).
|
| 38 |
+
MODEL_SAM_CPU = None
|
| 39 |
+
PLM_CPU = None
|
| 40 |
|
| 41 |
# ----------------- Helper: Download Logic -----------------
|
| 42 |
def download_if_needed(filename):
|
|
|
|
| 48 |
logging.info(f"Found local file: {filename}")
|
| 49 |
return filename
|
| 50 |
|
| 51 |
+
# hf_hub_download checks the cache automatically.
|
| 52 |
+
# It won't re-download if the file is already in the HF cache.
|
| 53 |
+
logging.info(f"Checking HF Cache for {filename}...")
|
| 54 |
try:
|
| 55 |
path = hf_hub_download(repo_id=HF_REPO_ID, filename=filename)
|
|
|
|
| 56 |
return path
|
| 57 |
except Exception as e:
|
| 58 |
raise FileNotFoundError(f"Could not find {filename} locally or in HF repo {HF_REPO_ID}. Error: {e}")
|
|
|
|
| 173 |
up = F.interpolate(crop, size=out_hw, mode="bilinear", align_corners=False)
|
| 174 |
return up[0, 0]
|
| 175 |
|
| 176 |
+
# ----------------- Model Loading (CPU Caching) -----------------
|
| 177 |
|
| 178 |
+
def ensure_models_loaded_on_cpu():
|
| 179 |
"""
|
| 180 |
+
Ensures models are loaded in Global CPU RAM.
|
| 181 |
+
This avoids re-reading from disk/cache on every run.
|
| 182 |
"""
|
| 183 |
+
global MODEL_SAM_CPU, PLM_CPU
|
| 184 |
|
| 185 |
+
if MODEL_SAM_CPU is not None and PLM_CPU is not None:
|
| 186 |
+
return # Already loaded in RAM
|
| 187 |
|
| 188 |
+
logging.info("Loading models into CPU RAM (this happens once)...")
|
| 189 |
|
| 190 |
# 1. Base SAM2 Model
|
| 191 |
base_path = download_if_needed(BASE_CKPT_NAME)
|
| 192 |
|
| 193 |
+
# Build on CPU
|
| 194 |
model = build_sam2(SAM2_CONFIG, base_path, device="cpu")
|
| 195 |
|
| 196 |
# 2. Fine-tuned Weights
|
|
|
|
| 198 |
sd = torch.load(final_path, map_location="cpu")
|
| 199 |
model.load_state_dict(sd.get("model", sd), strict=True)
|
| 200 |
|
| 201 |
+
# Save to Global (CPU)
|
| 202 |
+
MODEL_SAM_CPU = model
|
|
|
|
| 203 |
|
| 204 |
# 3. PLM Adapter
|
| 205 |
C = model.sam_mask_decoder.transformer_dim
|
|
|
|
| 213 |
lora_alpha=32,
|
| 214 |
lora_dropout=0.05,
|
| 215 |
dtype=torch.bfloat16,
|
| 216 |
+
device="cpu",
|
| 217 |
)
|
| 218 |
|
| 219 |
plm_path = download_if_needed(PLM_CKPT_NAME)
|
|
|
|
| 224 |
lora_path = download_if_needed(LORA_CKPT_NAME)
|
| 225 |
plm.load_lora(lora_path)
|
| 226 |
|
|
|
|
|
|
|
| 227 |
plm.eval()
|
| 228 |
+
PLM_CPU = plm
|
| 229 |
+
logging.info("Models successfully loaded into CPU RAM.")
|
| 230 |
|
|
|
|
|
|
|
| 231 |
|
| 232 |
+
@spaces.GPU(duration=120)
|
|
|
|
| 233 |
def run_prediction(image_pil, text_prompt):
|
| 234 |
if image_pil is None or not text_prompt:
|
| 235 |
return None, None
|
| 236 |
|
| 237 |
+
# 1. Ensure models are in RAM (Fast check)
|
| 238 |
+
ensure_models_loaded_on_cpu()
|
| 239 |
+
|
| 240 |
+
# 2. Move to GPU (The only 'loading' cost per run)
|
| 241 |
+
# We rely on the global variables
|
| 242 |
+
logging.info("Moving models to GPU...")
|
| 243 |
+
MODEL_SAM_CPU.to("cuda")
|
| 244 |
+
PLM_CPU.to("cuda")
|
| 245 |
+
|
| 246 |
predictor = None
|
| 247 |
|
| 248 |
try:
|
| 249 |
+
# Instantiate Predictor on GPU
|
| 250 |
+
predictor = SAM2ImagePredictor(MODEL_SAM_CPU)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
|
| 252 |
# 3. Preprocess Image
|
| 253 |
rgb_orig = np.array(image_pil.convert("RGB"))
|
|
|
|
| 265 |
temp_path = "temp_input.jpg"
|
| 266 |
image_pil.save(temp_path)
|
| 267 |
|
| 268 |
+
sp, dp = PLM_CPU([text_prompt], H_feat, W_feat, [temp_path])
|
| 269 |
|
| 270 |
# 6. Prepare SAM2 Decoder inputs
|
| 271 |
dec = predictor.model.sam_mask_decoder
|
|
|
|
| 308 |
raise e
|
| 309 |
|
| 310 |
finally:
|
| 311 |
+
# CRITICAL: Move models back to CPU
|
| 312 |
+
# This keeps the global model objects in CPU RAM for the next run.
|
| 313 |
+
# If we leave them on CUDA, they might be lost when ZeroGPU releases the device.
|
| 314 |
+
logging.info("Moving models back to CPU...")
|
| 315 |
+
MODEL_SAM_CPU.to("cpu")
|
| 316 |
+
PLM_CPU.to("cpu")
|
| 317 |
+
|
| 318 |
if predictor:
|
| 319 |
del predictor
|
| 320 |
torch.cuda.empty_cache()
|