update app
Browse files
app.py
CHANGED
|
@@ -21,24 +21,34 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
|
|
| 21 |
from plm_adapter_lora_with_image_input_only_text_positions import PLMLanguageAdapter
|
| 22 |
|
| 23 |
# ----------------- Configuration -----------------
|
| 24 |
-
REPO_MAP = {
|
| 25 |
-
"Stage 1": "aadarsh99/ConvSeg-Stage1",
|
| 26 |
-
"Stage 2": "aadarsh99/ConvSeg-Stage2"
|
| 27 |
-
}
|
| 28 |
SAM2_CONFIG = "sam2_hiera_l.yaml"
|
| 29 |
-
|
| 30 |
BASE_CKPT_NAME = "sam2_hiera_large.pt"
|
| 31 |
-
FINAL_CKPT_NAME = "fine_tuned_sam2_batched_100000.torch"
|
| 32 |
-
PLM_CKPT_NAME = "fine_tuned_sam2_batched_plm_100000.torch"
|
| 33 |
|
| 34 |
SQUARE_DIM = 1024
|
| 35 |
logging.basicConfig(level=logging.INFO)
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
"Stage
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
}
|
| 41 |
|
|
|
|
|
|
|
|
|
|
| 42 |
# ----------------- Helper Functions -----------------
|
| 43 |
def download_if_needed(repo_id, filename):
|
| 44 |
try:
|
|
@@ -70,29 +80,34 @@ def make_overlay(rgb: np.ndarray, mask: np.ndarray, key: str = "mask") -> Image.
|
|
| 70 |
stroke_layer = Image.new("RGBA", base.size, color + (255,))
|
| 71 |
stroke_layer.putalpha(edges)
|
| 72 |
|
| 73 |
-
# Composite safely
|
| 74 |
out = Image.alpha_composite(base, fill_layer)
|
| 75 |
out = Image.alpha_composite(out, stroke_layer)
|
| 76 |
|
| 77 |
return out.convert("RGB")
|
| 78 |
|
| 79 |
-
def ensure_models_loaded(
|
| 80 |
global MODEL_CACHE
|
| 81 |
-
if MODEL_CACHE[
|
| 82 |
return
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
|
|
|
|
|
|
| 86 |
|
| 87 |
# SAM2
|
|
|
|
| 88 |
base_path = download_if_needed(repo_id, BASE_CKPT_NAME)
|
| 89 |
model = build_sam2(SAM2_CONFIG, base_path, device="cpu")
|
| 90 |
-
|
|
|
|
|
|
|
| 91 |
sd = torch.load(final_path, map_location="cpu")
|
| 92 |
model.load_state_dict(sd.get("model", sd), strict=True)
|
| 93 |
|
| 94 |
# PLM
|
| 95 |
-
plm_path = download_if_needed(repo_id,
|
| 96 |
plm = PLMLanguageAdapter(
|
| 97 |
model_name="Qwen/Qwen2.5-VL-3B-Instruct",
|
| 98 |
transformer_dim=model.sam_mask_decoder.transformer_dim,
|
|
@@ -104,7 +119,7 @@ def ensure_models_loaded(stage):
|
|
| 104 |
plm.load_state_dict(plm_sd["plm"], strict=True)
|
| 105 |
plm.eval()
|
| 106 |
|
| 107 |
-
MODEL_CACHE[
|
| 108 |
|
| 109 |
# ----------------- GPU Inference -----------------
|
| 110 |
|
|
@@ -205,7 +220,11 @@ with gr.Blocks(title="SAM2 + PLM Segmentation") as demo:
|
|
| 205 |
text_prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., 'the surgical forceps'")
|
| 206 |
|
| 207 |
with gr.Row():
|
| 208 |
-
stage_select = gr.Radio(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
threshold_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Threshold")
|
| 210 |
|
| 211 |
run_btn = gr.Button("Run Inference", variant="primary")
|
|
|
|
| 21 |
from plm_adapter_lora_with_image_input_only_text_positions import PLMLanguageAdapter
|
| 22 |
|
| 23 |
# ----------------- Configuration -----------------
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
SAM2_CONFIG = "sam2_hiera_l.yaml"
|
|
|
|
| 25 |
BASE_CKPT_NAME = "sam2_hiera_large.pt"
|
|
|
|
|
|
|
| 26 |
|
| 27 |
SQUARE_DIM = 1024
|
| 28 |
logging.basicConfig(level=logging.INFO)
|
| 29 |
|
| 30 |
+
# ----- Model registry ---------------------------------------------------
# Maps each UI stage label to its Hugging Face Hub repo and the exact
# fine-tuned checkpoint filenames: "sam_filename" is the fine-tuned SAM2
# weights, "plm_filename" is the matching PLM adapter weights.  Both
# Stage 2 variants live in the same repo and differ only in which
# training-step checkpoint they load.
MODEL_CONFIGS = {
    "Stage 1": {
        "repo_id": "aadarsh99/ConvSeg-Stage1",
        "sam_filename": "fine_tuned_sam2_batched_100000.torch",
        "plm_filename": "fine_tuned_sam2_batched_plm_100000.torch",
    },
    "Stage 2 (grad-acc: 4)": {
        "repo_id": "aadarsh99/ConvSeg-Stage2",
        "sam_filename": "fine_tuned_sam2_batched_60000.torch",
        "plm_filename": "fine_tuned_sam2_batched_plm_60000.torch",
    },
    "Stage 2 (grad-acc: 8)": {
        "repo_id": "aadarsh99/ConvSeg-Stage2",
        "sam_filename": "fine_tuned_sam2_batched_100000.torch",
        "plm_filename": "fine_tuned_sam2_batched_plm_100000.torch",
    },
}

# One lazily-filled cache slot per configured stage; the loader populates
# "sam"/"plm" on first use.  Iterating the dict directly is the idiomatic
# equivalent of MODEL_CONFIGS.keys().
MODEL_CACHE = {stage: {"sam": None, "plm": None} for stage in MODEL_CONFIGS}
|
| 51 |
+
|
| 52 |
# ----------------- Helper Functions -----------------
|
| 53 |
def download_if_needed(repo_id, filename):
|
| 54 |
try:
|
|
|
|
| 80 |
stroke_layer = Image.new("RGBA", base.size, color + (255,))
|
| 81 |
stroke_layer.putalpha(edges)
|
| 82 |
|
| 83 |
+
# Composite safely
|
| 84 |
out = Image.alpha_composite(base, fill_layer)
|
| 85 |
out = Image.alpha_composite(out, stroke_layer)
|
| 86 |
|
| 87 |
return out.convert("RGB")
|
| 88 |
|
| 89 |
+
def ensure_models_loaded(stage_key):
|
| 90 |
global MODEL_CACHE
|
| 91 |
+
if MODEL_CACHE[stage_key]["sam"] is not None:
|
| 92 |
return
|
| 93 |
+
|
| 94 |
+
config = MODEL_CONFIGS[stage_key]
|
| 95 |
+
repo_id = config["repo_id"]
|
| 96 |
+
|
| 97 |
+
logging.info(f"Loading {stage_key} models from {repo_id} into CPU RAM...")
|
| 98 |
|
| 99 |
# SAM2
|
| 100 |
+
# Base model is always the same
|
| 101 |
base_path = download_if_needed(repo_id, BASE_CKPT_NAME)
|
| 102 |
model = build_sam2(SAM2_CONFIG, base_path, device="cpu")
|
| 103 |
+
|
| 104 |
+
# Load specific fine-tuned checkpoint
|
| 105 |
+
final_path = download_if_needed(repo_id, config["sam_filename"])
|
| 106 |
sd = torch.load(final_path, map_location="cpu")
|
| 107 |
model.load_state_dict(sd.get("model", sd), strict=True)
|
| 108 |
|
| 109 |
# PLM
|
| 110 |
+
plm_path = download_if_needed(repo_id, config["plm_filename"])
|
| 111 |
plm = PLMLanguageAdapter(
|
| 112 |
model_name="Qwen/Qwen2.5-VL-3B-Instruct",
|
| 113 |
transformer_dim=model.sam_mask_decoder.transformer_dim,
|
|
|
|
| 119 |
plm.load_state_dict(plm_sd["plm"], strict=True)
|
| 120 |
plm.eval()
|
| 121 |
|
| 122 |
+
MODEL_CACHE[stage_key]["sam"], MODEL_CACHE[stage_key]["plm"] = model, plm
|
| 123 |
|
| 124 |
# ----------------- GPU Inference -----------------
|
| 125 |
|
|
|
|
| 220 |
text_prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., 'the surgical forceps'")
|
| 221 |
|
| 222 |
with gr.Row():
|
| 223 |
+
# Stage picker: one radio option per entry in the model registry,
# defaulting to the 100k-step Stage 2 checkpoint.
stage_select = gr.Radio(
    label="Model Stage",
    choices=list(MODEL_CONFIGS),  # identical to list(MODEL_CONFIGS.keys())
    value="Stage 2 (grad-acc: 8)",
)
|
| 228 |
threshold_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Threshold")
|
| 229 |
|
| 230 |
run_btn = gr.Button("Run Inference", variant="primary")
|