aadarsh99 committed on
Commit
6221bbd
·
1 Parent(s): 3d4a272

update app

Browse files
Files changed (1) hide show
  1. app.py +39 -20
app.py CHANGED
@@ -21,24 +21,34 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
21
  from plm_adapter_lora_with_image_input_only_text_positions import PLMLanguageAdapter
22
 
23
  # ----------------- Configuration -----------------
24
- REPO_MAP = {
25
- "Stage 1": "aadarsh99/ConvSeg-Stage1",
26
- "Stage 2": "aadarsh99/ConvSeg-Stage2"
27
- }
28
  SAM2_CONFIG = "sam2_hiera_l.yaml"
29
-
30
  BASE_CKPT_NAME = "sam2_hiera_large.pt"
31
- FINAL_CKPT_NAME = "fine_tuned_sam2_batched_100000.torch"
32
- PLM_CKPT_NAME = "fine_tuned_sam2_batched_plm_100000.torch"
33
 
34
  SQUARE_DIM = 1024
35
  logging.basicConfig(level=logging.INFO)
36
 
37
- MODEL_CACHE = {
38
- "Stage 1": {"sam": None, "plm": None},
39
- "Stage 2": {"sam": None, "plm": None}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  }
41
 
 
 
 
42
  # ----------------- Helper Functions -----------------
43
  def download_if_needed(repo_id, filename):
44
  try:
@@ -70,29 +80,34 @@ def make_overlay(rgb: np.ndarray, mask: np.ndarray, key: str = "mask") -> Image.
70
  stroke_layer = Image.new("RGBA", base.size, color + (255,))
71
  stroke_layer.putalpha(edges)
72
 
73
- # Composite safely (Module-level returns new images, no in-place None issues)
74
  out = Image.alpha_composite(base, fill_layer)
75
  out = Image.alpha_composite(out, stroke_layer)
76
 
77
  return out.convert("RGB")
78
 
79
- def ensure_models_loaded(stage):
80
  global MODEL_CACHE
81
- if MODEL_CACHE[stage]["sam"] is not None:
82
  return
83
-
84
- repo_id = REPO_MAP[stage]
85
- logging.info(f"Loading {stage} models from {repo_id} into CPU RAM...")
 
 
86
 
87
  # SAM2
 
88
  base_path = download_if_needed(repo_id, BASE_CKPT_NAME)
89
  model = build_sam2(SAM2_CONFIG, base_path, device="cpu")
90
- final_path = download_if_needed(repo_id, FINAL_CKPT_NAME)
 
 
91
  sd = torch.load(final_path, map_location="cpu")
92
  model.load_state_dict(sd.get("model", sd), strict=True)
93
 
94
  # PLM
95
- plm_path = download_if_needed(repo_id, PLM_CKPT_NAME)
96
  plm = PLMLanguageAdapter(
97
  model_name="Qwen/Qwen2.5-VL-3B-Instruct",
98
  transformer_dim=model.sam_mask_decoder.transformer_dim,
@@ -104,7 +119,7 @@ def ensure_models_loaded(stage):
104
  plm.load_state_dict(plm_sd["plm"], strict=True)
105
  plm.eval()
106
 
107
- MODEL_CACHE[stage]["sam"], MODEL_CACHE[stage]["plm"] = model, plm
108
 
109
  # ----------------- GPU Inference -----------------
110
 
@@ -205,7 +220,11 @@ with gr.Blocks(title="SAM2 + PLM Segmentation") as demo:
205
  text_prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., 'the surgical forceps'")
206
 
207
  with gr.Row():
208
- stage_select = gr.Radio(choices=["Stage 1", "Stage 2"], value="Stage 1", label="Model Stage")
 
 
 
 
209
  threshold_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Threshold")
210
 
211
  run_btn = gr.Button("Run Inference", variant="primary")
 
21
  from plm_adapter_lora_with_image_input_only_text_positions import PLMLanguageAdapter
22
 
23
  # ----------------- Configuration -----------------
 
 
 
 
24
  SAM2_CONFIG = "sam2_hiera_l.yaml"
 
25
  BASE_CKPT_NAME = "sam2_hiera_large.pt"
 
 
26
 
27
  SQUARE_DIM = 1024
28
  logging.basicConfig(level=logging.INFO)
29
 
30
+ # Refactored to store specific filenames per model choice
31
+ MODEL_CONFIGS = {
32
+ "Stage 1": {
33
+ "repo_id": "aadarsh99/ConvSeg-Stage1",
34
+ "sam_filename": "fine_tuned_sam2_batched_100000.torch",
35
+ "plm_filename": "fine_tuned_sam2_batched_plm_100000.torch"
36
+ },
37
+ "Stage 2 (grad-acc: 4)": {
38
+ "repo_id": "aadarsh99/ConvSeg-Stage2",
39
+ "sam_filename": "fine_tuned_sam2_batched_60000.torch",
40
+ "plm_filename": "fine_tuned_sam2_batched_plm_60000.torch"
41
+ },
42
+ "Stage 2 (grad-acc: 8)": {
43
+ "repo_id": "aadarsh99/ConvSeg-Stage2",
44
+ "sam_filename": "fine_tuned_sam2_batched_100000.torch",
45
+ "plm_filename": "fine_tuned_sam2_batched_plm_100000.torch"
46
+ }
47
  }
48
 
49
+ # Dynamically create cache keys based on config
50
+ MODEL_CACHE = {k: {"sam": None, "plm": None} for k in MODEL_CONFIGS.keys()}
51
+
52
  # ----------------- Helper Functions -----------------
53
  def download_if_needed(repo_id, filename):
54
  try:
 
80
  stroke_layer = Image.new("RGBA", base.size, color + (255,))
81
  stroke_layer.putalpha(edges)
82
 
83
+ # Composite safely
84
  out = Image.alpha_composite(base, fill_layer)
85
  out = Image.alpha_composite(out, stroke_layer)
86
 
87
  return out.convert("RGB")
88
 
89
+ def ensure_models_loaded(stage_key):
90
  global MODEL_CACHE
91
+ if MODEL_CACHE[stage_key]["sam"] is not None:
92
  return
93
+
94
+ config = MODEL_CONFIGS[stage_key]
95
+ repo_id = config["repo_id"]
96
+
97
+ logging.info(f"Loading {stage_key} models from {repo_id} into CPU RAM...")
98
 
99
  # SAM2
100
+ # Base model is always the same
101
  base_path = download_if_needed(repo_id, BASE_CKPT_NAME)
102
  model = build_sam2(SAM2_CONFIG, base_path, device="cpu")
103
+
104
+ # Load specific fine-tuned checkpoint
105
+ final_path = download_if_needed(repo_id, config["sam_filename"])
106
  sd = torch.load(final_path, map_location="cpu")
107
  model.load_state_dict(sd.get("model", sd), strict=True)
108
 
109
  # PLM
110
+ plm_path = download_if_needed(repo_id, config["plm_filename"])
111
  plm = PLMLanguageAdapter(
112
  model_name="Qwen/Qwen2.5-VL-3B-Instruct",
113
  transformer_dim=model.sam_mask_decoder.transformer_dim,
 
119
  plm.load_state_dict(plm_sd["plm"], strict=True)
120
  plm.eval()
121
 
122
+ MODEL_CACHE[stage_key]["sam"], MODEL_CACHE[stage_key]["plm"] = model, plm
123
 
124
  # ----------------- GPU Inference -----------------
125
 
 
220
  text_prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., 'the surgical forceps'")
221
 
222
  with gr.Row():
223
+ stage_select = gr.Radio(
224
+ choices=list(MODEL_CONFIGS.keys()),
225
+ value="Stage 2 (grad-acc: 8)",
226
+ label="Model Stage"
227
+ )
228
  threshold_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Threshold")
229
 
230
  run_btn = gr.Button("Run Inference", variant="primary")