File size: 13,123 Bytes
97c4d82
 
 
 
eb03911
3d4a272
97c4d82
 
 
 
 
c9aa521
0a57513
39ba93b
97c4d82
 
c9aa521
97c4d82
 
 
 
 
c9aa521
 
 
 
97c4d82
 
37ff04d
 
0a57513
97c4d82
 
 
c9aa521
 
6221bbd
7e46975
c84ea63
39ba93b
3d4a272
c84ea63
39ba93b
c84ea63
97c4d82
 
 
c9aa521
7e46975
3d4a272
7e46975
97c4d82
 
3d4a272
97c4d82
 
c84ea63
c9aa521
7e46975
3d4a272
97c4d82
 
c9aa521
c84ea63
 
3d4a272
 
97c4d82
c9aa521
3d4a272
 
 
b6001f7
c9aa521
c84ea63
c9aa521
3d4a272
6221bbd
c9aa521
3d4a272
c9aa521
 
39ba93b
6221bbd
c9aa521
 
461d3a3
3d4a272
0a57513
c9aa521
3d4a272
 
 
 
 
 
 
 
 
0a57513
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6001f7
3d4a272
c9aa521
 
 
97c4d82
3d4a272
c84ea63
9d1694a
17751a0
 
7e46975
97c4d82
17751a0
 
b485573
 
 
17751a0
 
c9aa521
 
 
3d4a272
c9aa521
3d4a272
 
9d1694a
aeaa431
c84ea63
 
 
7e46975
3d4a272
c9aa521
7e46975
 
c84ea63
 
3d4a272
 
 
c9aa521
c84ea63
 
 
 
c9aa521
3d4a272
 
c9aa521
17751a0
3d4a272
c9aa521
3d4a272
 
 
 
 
 
 
 
 
 
c84ea63
aeaa431
c9aa521
7e46975
3d4a272
c9aa521
3d4a272
7e46975
3d4a272
 
c84ea63
c9aa521
3d4a272
 
96c10ec
3d4a272
17751a0
 
3d4a272
 
96c10ec
3d4a272
 
c9aa521
aeaa431
c9aa521
3d4a272
 
461d3a3
97c4d82
17751a0
c9aa521
7e46975
 
 
 
17751a0
 
 
7e46975
c9aa521
 
 
 
 
 
 
 
 
 
 
 
17751a0
 
b1ff212
17751a0
 
b1ff212
 
 
 
 
 
 
 
 
 
 
 
 
17751a0
c9aa521
 
 
 
 
 
 
 
 
7e46975
b1ff212
 
 
 
 
 
 
c9aa521
3d4a272
7e46975
c9aa521
 
 
 
 
 
 
97c4d82
c9aa521
 
 
3d4a272
17751a0
 
c9aa521
17751a0
 
 
 
 
b1ff212
 
17751a0
 
 
 
 
 
 
 
 
3d4a272
c9aa521
 
 
 
 
 
 
 
 
 
 
 
 
97c4d82
c9aa521
b1ff212
 
 
c9aa521
 
 
b1ff212
 
 
c9aa521
b1ff212
 
 
 
 
 
 
 
c9aa521
 
 
 
 
97c4d82
 
c9aa521
7e46975
 
 
c9aa521
7e46975
 
 
 
97c4d82
 
 
c9aa521
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
import os
import logging
import hashlib
import sys
import traceback
import tempfile
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import gradio as gr
from PIL import Image, ImageFilter, ImageChops
from huggingface_hub import hf_hub_download, snapshot_download # <--- Added snapshot_download
import spaces

# --- IMPORT YOUR CUSTOM MODULES ---
# Ensure these files are present in your file structure
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from plm_adapter_lora_with_image_input_only_text_positions import PLMLanguageAdapter

# ----------------- Configuration -----------------
logging.basicConfig(level=logging.INFO)

# Hugging Face Hub repository that hosts every checkpoint this demo loads.
REPO_ID = "aadarsh99/ConvSeg-Stage2"

# SAM2 architecture config plus the base checkpoint it is first built from.
SAM2_CONFIG = "sam2_hiera_l.yaml"
BASE_CKPT_NAME = "sam2_hiera_large.pt"

# Fine-tuned weights layered on top of the base model (the "90000" suffix is
# presumably the training step -- confirm with the training run).
FINE_TUNED_SAM = "fine_tuned_sam2_batched_90000.torch"
FINE_TUNED_PLM = "fine_tuned_sam2_batched_plm_90000.torch"
FINE_TUNED_LORA = "lora_plm_adapter_90000"  # a folder (not a file) inside the repo

# Side length, in pixels, of the square letterboxed image fed to the SAM2 encoder.
SQUARE_DIM = 1024

# Process-wide model cache; filled once by ensure_models_loaded().
MODEL_CACHE = dict(sam=None, plm=None)

# ----------------- Helper Functions -----------------
def download_if_needed(repo_id, filename):
    """Fetch ``filename`` from the Hub repo ``repo_id`` and return its local path.

    Args:
        repo_id: Hugging Face Hub repository id, e.g. "org/name".
        filename: Path of the file inside that repository.

    Returns:
        The local filesystem path returned by ``hf_hub_download``.

    Raises:
        FileNotFoundError: if the download fails for any reason. The underlying
            exception is chained as ``__cause__`` so its traceback is preserved.
    """
    logging.info("Checking %s in %s...", filename, repo_id)
    try:
        return hf_hub_download(repo_id=repo_id, filename=filename)
    except Exception as e:
        # Every failure is reported as FileNotFoundError (the type callers see),
        # but the original error is chained explicitly instead of being lost.
        raise FileNotFoundError(f"Could not find {filename} in {repo_id}. Error: {e}") from e

def stable_color(key: str):
    """Map ``key`` deterministically onto one of six bright RGB colours.

    The key is stringified and hashed with SHA-256, so the same key always gets
    the same colour, including across interpreter runs.

    Returns:
        An ``(r, g, b)`` tuple of ints in 0-255.
    """
    # Bright, distinct colors for overlays
    palette = ("#3A86FF", "#FF006E", "#43AA8B", "#F3722C", "#8338EC", "#90BE6D")
    digest = hashlib.sha256(str(key).encode("utf-8")).digest()
    # Big-endian int of the raw digest equals int(hexdigest, 16).
    chosen = palette[int.from_bytes(digest, "big") % len(palette)][1:]
    return (int(chosen[0:2], 16), int(chosen[2:4], 16), int(chosen[4:6], 16))

def make_overlay(rgb: np.ndarray, mask: np.ndarray, key: str = "mask") -> Image.Image:
    """Draw ``mask`` on top of ``rgb`` as a translucent fill plus a solid outline.

    Args:
        rgb: Image array; cast to uint8 and converted to RGBA for compositing.
        mask: Array of the same height/width; any value > 0 counts as inside.
        key: Value hashed by ``stable_color`` to pick the overlay colour.

    Returns:
        The composited image as an RGB ``PIL.Image``.
    """
    canvas = Image.fromarray(rgb.astype(np.uint8)).convert("RGBA")
    inside = (mask > 0).astype(np.uint8)
    r, g, b = stable_color(key)

    # Translucent fill: alpha 140 inside the mask, fully transparent elsewhere.
    tint = Image.new("RGBA", canvas.size, (r, g, b, 0))
    tint.putalpha(Image.fromarray(inside * 140, "L"))

    # Outline: 3x3 dilation minus 3x3 erosion leaves a thin band along the border.
    binary = Image.fromarray(inside * 255, "L")
    dilated = binary.filter(ImageFilter.MaxFilter(3))
    eroded = binary.filter(ImageFilter.MinFilter(3))
    outline = Image.new("RGBA", canvas.size, (r, g, b, 255))
    outline.putalpha(ImageChops.difference(dilated, eroded))

    composed = Image.alpha_composite(Image.alpha_composite(canvas, tint), outline)
    return composed.convert("RGB")

def ensure_models_loaded():
    """Populate MODEL_CACHE with the fine-tuned SAM2 model and PLM adapter, once.

    Everything is built and loaded on CPU; run_prediction moves the models to the
    GPU for each request. Both cache entries are assigned together at the very end
    of a successful load, so checking one of them is enough to detect a warm cache.

    Raises:
        FileNotFoundError: a checkpoint could not be fetched (via download_if_needed).
        RuntimeError: the LoRA folder could not be downloaded or loaded.
    """
    # MODEL_CACHE is only mutated, never rebound, so no `global` statement is needed.
    if MODEL_CACHE["sam"] is not None:
        return

    logging.info("Loading models from %s...", REPO_ID)

    # 1. SAM2: build from the base checkpoint, then overwrite with fine-tuned weights.
    base_path = download_if_needed(REPO_ID, BASE_CKPT_NAME)
    model = build_sam2(SAM2_CONFIG, base_path, device="cpu")

    sam_ckpt_path = download_if_needed(REPO_ID, FINE_TUNED_SAM)
    # NOTE: torch.load unpickles; acceptable only because REPO_ID is a fixed, trusted repo.
    sd = torch.load(sam_ckpt_path, map_location="cpu")
    # The checkpoint may wrap the state dict under "model" or be the state dict itself.
    model.load_state_dict(sd.get("model", sd), strict=True)

    # 2. PLM language adapter; its output width must match the SAM2 mask decoder.
    plm_path = download_if_needed(REPO_ID, FINE_TUNED_PLM)
    plm = PLMLanguageAdapter(
        model_name="Qwen/Qwen2.5-VL-3B-Instruct",
        transformer_dim=model.sam_mask_decoder.transformer_dim,
        n_sparse_tokens=0, use_dense_bias=True, use_lora=True,
        lora_r=16, lora_alpha=32, lora_dropout=0.05,
        dtype=torch.bfloat16, device="cpu",
    )
    plm_sd = torch.load(plm_path, map_location="cpu")
    plm.load_state_dict(plm_sd["plm"], strict=True)

    # 3. LoRA weights are stored as a folder in the repo, so fetch only that subtree.
    try:
        logging.info("Downloading LoRA folder: %s...", FINE_TUNED_LORA)
        # snapshot_download returns the snapshot root; allow_patterns restricts it to the LoRA folder.
        cache_root = snapshot_download(repo_id=REPO_ID, allow_patterns=f"{FINE_TUNED_LORA}/*")
        lora_dir_path = os.path.join(cache_root, FINE_TUNED_LORA)

        logging.info("Loading LoRA from %s...", lora_dir_path)
        plm.load_lora(lora_dir_path)
    except Exception as e:
        raise RuntimeError(f"Failed to load LoRA weights: {e}") from e

    plm.eval()

    MODEL_CACHE["sam"] = model
    MODEL_CACHE["plm"] = plm
    logging.info("Models loaded successfully.")

# ----------------- GPU Inference -----------------

@spaces.GPU(duration=120)
def run_prediction(image_pil, user_text, threshold=0.5):
    """Segment the part of ``image_pil`` described by ``user_text``.

    The text is turned into the prompt "Segment the <text>" (one trailing '.', '!'
    or '?' removed). The image is letterboxed to SQUARE_DIM x SQUARE_DIM, and the
    SAM2 mask decoder is driven by prompt embeddings from the PLM adapter. The
    best-scoring mask is mapped back to the original resolution.

    Args:
        image_pil: Input PIL image, or None when nothing was uploaded.
        user_text: The user's description (without the fixed prefix).
        threshold: Probability cut-off used to binarise the mask.

    Returns:
        ``(overlay, heatmap, prob)``: an RGB overlay image, a JET-coloured
        confidence heatmap image, and the (H, W) float probability map (cached by
        the UI for CPU-only re-thresholding). All three are None when the image
        or text is missing.

    Raises:
        gr.Error: if anything fails during inference (details go to the log).
    """
    # Whitespace-only text counts as missing; otherwise the prompt would be a bare "Segment the ".
    if image_pil is None or not (user_text and user_text.strip()):
        return None, None, None

    # --- Prepend the required prefix ---
    full_prompt = f"Segment the {user_text.strip()}"
    # Remove one trailing punctuation mark for consistency.
    if full_prompt[-1] in {".", "!", "?"}:
        full_prompt = full_prompt[:-1]
    logging.info("Processing prompt: %s", full_prompt)

    ensure_models_loaded()
    sam_model = MODEL_CACHE["sam"]
    plm_model = MODEL_CACHE["plm"]

    try:
        # Moving to the GPU inside the try guarantees the finally clause returns the
        # weights to CPU even if one of these transfers fails.
        sam_model.to("cuda")
        plm_model.to("cuda")

        with torch.inference_mode():
            predictor = SAM2ImagePredictor(sam_model)
            rgb_orig = np.array(image_pil.convert("RGB"))
            H, W = rgb_orig.shape[:2]

            # Letterbox: scale the longer side to SQUARE_DIM, then pad to a square.
            # round() avoids float truncation (e.g. 1023.999 -> 1023); max(1, ...)
            # keeps extremely thin images from producing a zero-size resize.
            scale = SQUARE_DIM / max(H, W)
            nw = max(1, round(W * scale))
            nh = max(1, round(H * scale))
            top, left = (SQUARE_DIM - nh) // 2, (SQUARE_DIM - nw) // 2

            rgb_sq = cv2.resize(rgb_orig, (nw, nh), interpolation=cv2.INTER_LINEAR)
            rgb_sq = cv2.copyMakeBorder(
                rgb_sq, top, SQUARE_DIM - nh - top, left, SQUARE_DIM - nw - left,
                cv2.BORDER_CONSTANT, value=0,
            )

            # Image encoder: keep the last entry of each cached feature and re-add a batch dim.
            predictor.set_image(rgb_sq)
            image_emb = predictor._features["image_embed"][-1].unsqueeze(0)
            hi = [lvl[-1].unsqueeze(0) for lvl in predictor._features["high_res_feats"]]

            # The PLM adapter takes an image *path*. JPEG cannot store alpha or
            # palette modes, so write an RGB copy to avoid a save error.
            with tempfile.NamedTemporaryFile(suffix=".jpg") as tmp:
                image_pil.convert("RGB").save(tmp.name)
                # sp / dp feed the decoder's sparse / dense prompt embeddings.
                sp, dp = plm_model([full_prompt], image_emb.shape[2], image_emb.shape[3], [tmp.name])

            # SAM2 mask decoder, run in the decoder's own device/dtype.
            dec = sam_model.sam_mask_decoder
            dev, dtype = next(dec.parameters()).device, next(dec.parameters()).dtype

            low, scores, _, _ = dec(
                image_embeddings=image_emb.to(dev, dtype),
                image_pe=sam_model.sam_prompt_encoder.get_dense_pe().to(dev, dtype),
                sparse_prompt_embeddings=sp.to(dev, dtype),
                dense_prompt_embeddings=dp.to(dev, dtype),
                multimask_output=True, repeat_image=False,
                high_res_features=[h.to(dev, dtype) for h in hi],
            )

            # Upsample to the padded square, pick the best-scoring mask, crop off the
            # padding, and resize back to the original resolution.
            logits = predictor._transforms.postprocess_masks(low, (SQUARE_DIM, SQUARE_DIM))
            best_idx = scores.argmax().item()

            logit_crop = logits[0, best_idx, top:top + nh, left:left + nw].unsqueeze(0).unsqueeze(0)
            logit_full = F.interpolate(logit_crop, size=(H, W), mode="bilinear", align_corners=False)[0, 0]

            prob = torch.sigmoid(logit_full).float().cpu().numpy()

        # Visuals
        heatmap_cv = cv2.applyColorMap((prob * 255).astype(np.uint8), cv2.COLORMAP_JET)
        heatmap_rgb = cv2.cvtColor(heatmap_cv, cv2.COLOR_BGR2RGB)

        mask = (prob > threshold).astype(np.uint8) * 255
        # Use full_prompt as the colour key so re-thresholding keeps the same colour.
        overlay = make_overlay(rgb_orig, mask, key=full_prompt)

        return overlay, Image.fromarray(heatmap_rgb), prob

    except Exception:
        logging.exception("Inference failed")
        raise gr.Error("Inference failed. Please check logs.")
    finally:
        # Return weights to CPU and release cached GPU memory between requests.
        sam_model.to("cpu")
        plm_model.to("cpu")
        torch.cuda.empty_cache()

def overlay_color_key(user_text):
    """Return the overlay colour key that run_prediction derives from ``user_text``.

    This mirrors run_prediction's prompt construction: prefix "Segment the ",
    strip whitespace, and drop one trailing '.', '!' or '?'. Matching it keeps the
    overlay colour unchanged when only the threshold moves. Returns "mask" when
    there is no text.
    """
    if not user_text:
        return "mask"
    key = f"Segment the {user_text.strip()}"
    if key[-1] in {".", "!", "?"}:
        key = key[:-1]
    return key


def update_threshold_ui(image_pil, user_text, threshold, cached_prob):
    """Real-time update using CPU only (no GPU quota usage).

    Re-binarises the probability map cached from the last run_prediction call at
    ``threshold`` and redraws the overlay.

    Returns:
        The new overlay image. Returns None (clearing the overlay) when there is no
        image or no cached result, or when the cached map does not match the
        current image size (e.g. a different image was uploaded after the last run).
    """
    if image_pil is None or cached_prob is None:
        return None
    rgb_orig = np.array(image_pil.convert("RGB"))
    # The cached map is not reset when a new image is loaded. A size mismatch
    # would make make_overlay fail, because its mask layers must match the image.
    if cached_prob.shape != rgb_orig.shape[:2]:
        return None
    mask = (cached_prob > threshold).astype(np.uint8) * 255
    return make_overlay(rgb_orig, mask, key=overlay_color_key(user_text))

# ----------------- UI Styling & Layout -----------------

# Extra CSS passed to gr.Blocks(css=...). The classes below are attached in the
# layout further down:
#   .subtitle         - the tagline <div> under the page title
#   .prefix-container - the inner <div> holding the fixed "Segment the" label
#   .prefix-box       - elem_classes of the gr.HTML component wrapping that label
# The string is sent to the browser verbatim, including its /* */ comments.
custom_css = """
h1 {
    text-align: center;
    display: block;
}
.subtitle {
    text-align: center;
    font-size: 1.1em;
    margin-bottom: 20px;
}
.prefix-container {
    display: flex;
    align-items: center; 
    justify-content: center;
    height: 100%;
    /* Match Gradio Textbox font style */
    font-family: var(--font-sans); 
    font-size: var(--input-text-size);
    font-weight: 400;
    color: var(--body-text-color);
}
/* Force the HTML container to match height of neighbor */
.prefix-box {
    display: flex;
    flex-direction: column;
    justify-content: center;
    height: 100% !important;
    min-height: 42px; /* Standard Gradio input height fallback */
}
"""

# Soft theme (blue primary hue, slate neutrals). Primary-button fills are pinned to
# the 600 (normal) and 700 (hover) shades of the primary palette.
_primary_button_overrides = {
    "button_primary_background_fill": "*primary_600",
    "button_primary_background_fill_hover": "*primary_700",
}
theme = gr.themes.Soft(primary_hue="blue", neutral_hue="slate").set(**_primary_button_overrides)

def example_handler(text):
    """Drop the leading "Segment the " from an example prompt.

    The example gallery writes full sentences into a hidden textbox, while the
    visible textbox holds only the text after the fixed prefix label. Empty/None
    values and text without the prefix are returned unchanged.
    """
    prefix = "Segment the "
    if not text or not text.startswith(prefix):
        return text
    return text[len(prefix):]

# Build the Gradio app. Components are rendered in the order they are created
# inside these context managers, so statement order here defines the layout.
with gr.Blocks(theme=theme, css=custom_css, title="ConvSeg-Net Demo") as demo:
    # Holds the probability map returned by run_prediction (its third output);
    # update_threshold_ui reads it so the slider can re-threshold without the GPU.
    # NOTE: nothing in this block resets it when input_image changes.
    prob_state = gr.State()
    
    # Header
    gr.Markdown("# 🧩 Conversational Image Segmentation")
    gr.Markdown(
        "<div class='subtitle'>Grounding abstract concepts and physics-based reasoning into pixel-accurate masks.<br>"
        "Powered by <b>SAM2 + Qwen2.5-VL</b></div>"
    )

    with gr.Row():
        # --- Left Column: Inputs ---
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input Image", height=400)
            
            # Custom prompt input layout
            gr.Markdown("**Conversational Prompt**")
            with gr.Group():
                with gr.Row(equal_height=True):
                    # Fixed Prefix: display only. run_prediction prepends the same
                    # "Segment the " text in code, so the user types only the rest.
                    gr.HTML(
                        "<div class='prefix-container'>Segment the</div>", 
                        elem_classes="prefix-box",
                        min_width=100,
                        max_width=100
                    )
                    # User Input
                    text_prompt = gr.Textbox(
                        show_label=False,
                        container=False, 
                        placeholder="object that is prone to rolling...",
                        lines=1,
                        scale=5
                    )
            
            with gr.Accordion("⚙️ Advanced Options", open=False):
                # Passed to run_prediction and also wired to update_threshold_ui below.
                threshold_slider = gr.Slider(
                    0.0, 1.0, value=0.5, step=0.01, 
                    label="Mask Confidence Threshold",
                    info="Adjust after running to refine the mask edges."
                )

            run_btn = gr.Button("🚀 Run Segmentation", variant="primary", size="lg")

        # --- Right Column: Outputs ---
        with gr.Column(scale=1):
            out_overlay = gr.Image(label="Segmentation Result", type="pil")
            out_heatmap = gr.Image(label="Confidence Heatmap", type="pil")

    # --- Examples Section ---
    # Hidden textbox to capture the full prompt from the example gallery
    hidden_example_text = gr.Textbox(visible=False)

    gr.Markdown("### 📝 Try Examples")
    gr.Examples(
        examples=[
            ["./examples/elephants.png", "Segment the elephant acting as the vanguard of the herd."],
            ["./examples/luggage.png", "Segment the luggage resting precariously."],
            ["./examples/veggies.png", "Segment the produce harvested from underground."],
        ],
        inputs=[input_image, hidden_example_text], # Output full text to hidden box
    )

    # When hidden box updates (from click), strip the prefix and update the visible box.
    # NOTE(review): `.change` presumably fires only when the hidden value actually
    # changes, so re-clicking the same example after editing the prompt may not
    # refill it -- confirm against the Gradio version in use.
    hidden_example_text.change(
        fn=example_handler,
        inputs=hidden_example_text,
        outputs=text_prompt
    )

    # --- Event Handling ---
    
    # 1. Run Inference (GPU)
    # The outputs order must match run_prediction's (overlay, heatmap, prob) return tuple.
    run_btn.click(
        fn=run_prediction,
        inputs=[input_image, text_prompt, threshold_slider],
        outputs=[out_overlay, out_heatmap, prob_state]
    )

    # 2. Update Threshold (CPU - Instant)
    # Re-thresholds the cached prob_state; only the overlay is redrawn (the
    # heatmap is threshold-independent).
    threshold_slider.change(
        fn=update_threshold_ui,
        inputs=[input_image, text_prompt, threshold_slider, prob_state],
        outputs=[out_overlay]
    )

if __name__ == "__main__":
    # Enable Gradio's request queue on the app, then start the server.
    demo.queue()
    demo.launch()