File size: 6,378 Bytes
201a9d0
869a58c
 
201a9d0
 
 
 
ba2f58e
dafd3a4
201a9d0
 
 
4792708
201a9d0
 
 
 
 
869a58c
201a9d0
0b858e3
 
 
 
4792708
201a9d0
 
 
 
 
 
4792708
869a58c
8560169
201a9d0
 
4792708
201a9d0
 
 
 
 
869a58c
4792708
 
 
 
 
 
 
 
 
 
 
201a9d0
 
4792708
 
 
 
201a9d0
 
 
 
 
2feca28
201a9d0
 
 
 
4792708
201a9d0
4792708
 
 
 
201a9d0
 
 
 
 
 
 
 
 
 
 
4792708
201a9d0
 
 
 
 
 
 
4792708
 
201a9d0
869a58c
9402dd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4792708
869a58c
201a9d0
 
 
 
 
 
869a58c
201a9d0
 
 
 
 
 
 
 
 
 
4792708
 
201a9d0
 
 
 
 
 
 
 
 
 
 
 
 
ba2f58e
201a9d0
 
4792708
ba2f58e
4792708
 
 
 
 
 
 
 
 
 
 
 
 
 
201a9d0
 
ba2f58e
 
 
4792708
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""
SAM3 segmentation execution - Optimized & Simplified.
Removed morphological filtering for maximum recall and speed.
"""
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from skimage.measure import regionprops
import torch
import torchvision
import numpy as np
import time
from PIL import Image

from ..config.schemas import ComponentRequest
from ..config.dependencies import AnalysisDeps

MAX_DIMENSION = 1024  # Speed optimization: Downscale large images

MIN_SOLIDITY = 0.50
MIN_CIRCULARITY = 0.1


# Use /tmp for all outputs
OUTPUT_DIR = "/tmp"

def execute_segmentation(deps: AnalysisDeps, request: ComponentRequest) -> str:
    """
    Execute SAM3 segmentation for the given component request.

    Pipeline: load the image from ``deps.image_path``, downscale it so the
    longest side is at most ``MAX_DIMENSION``, convert the request's
    normalized (0-1000) boxes to pixel coordinates, run the SAM model,
    filter masks by morphology (solidity/circularity) and NMS, then save a
    visualization PNG and a compressed ``.npz`` of the kept masks.

    Args:
        deps: Provides ``image_path``, ``sam_model``, ``sam_processor`` and
            ``device``. If model or processor is None, returns a mock message.
        request: Component to segment; supplies ``entity``, ``color`` and
            ``bboxes`` (normalized 0-1000 coordinates).

    Returns:
        A status string. On success it embeds ``MASK_FILE=`` and
        ``PLOT_FILE=`` paths; on failure, a human-readable error message.
    """
    t_start = time.time()

    text_prompt = f"{request.color} {request.entity}"
    print(f"\n[Engine] Segmenting: '{text_prompt}' ({len(request.bboxes)} boxes).")

    # 1. Load image
    try:
        raw_image = Image.open(deps.image_path).convert("RGB")
    except Exception as e:
        return f"Error loading image: {e}"

    # 2. Resize image if too large (critical for speed)
    w, h = raw_image.size
    if max(w, h) > MAX_DIMENSION:
        scale_factor = MAX_DIMENSION / max(w, h)
        new_w = int(w * scale_factor)
        new_h = int(h * scale_factor)
        raw_image = raw_image.resize((new_w, new_h), Image.Resampling.LANCZOS)
        print(f"[Engine] ⚡ Resized image from {w}x{h} to {new_w}x{new_h} (Speedup)")
        w, h = new_w, new_h

    # 3. Convert normalized coords (0-1000) to pixel coords in the
    #    (possibly resized) image space, as [x_min, y_min, x_max, y_max].
    sam_input_boxes = [
        [
            (box.xmin / 1000) * w,
            (box.ymin / 1000) * h,
            (box.xmax / 1000) * w,
            (box.ymax / 1000) * h,
        ]
        for box in request.bboxes
    ]

    if not sam_input_boxes:
        return "No valid boxes provided."

    # Build output paths from OUTPUT_DIR (was hard-coded "/tmp", bypassing
    # the constant declared for exactly this purpose).
    safe_label = request.entity.replace(" ", "_").lower()
    plot_filename = f"{OUTPUT_DIR}/out_{safe_label}.png"
    data_filename = f"{OUTPUT_DIR}/data_{safe_label}.npz"

    if deps.sam_model is None or deps.sam_processor is None:
        return f"[Mock] Would segment '{text_prompt}'."

    # 4. Inference
    print("[Engine] Running Inference...")
    t_inf = time.time()

    sam_input_labels = [[1] * len(sam_input_boxes)]  # 1 = positive box prompt
    input_boxes_batch = [sam_input_boxes]  # batch of a single image

    inputs = deps.sam_processor(
        images=raw_image,
        text=text_prompt,
        input_boxes=input_boxes_batch,
        input_boxes_labels=sam_input_labels,
        return_tensors="pt"
    ).to(deps.device)

    with torch.inference_mode():
        outputs = deps.sam_model(**inputs)

    results = deps.sam_processor.post_process_instance_segmentation(
        outputs,
        threshold=0.3,
        target_sizes=inputs["original_sizes"].tolist()
    )[0]

    print(f"[Engine] Inference took {time.time() - t_inf:.2f}s")

    # 5. Morphology filtering: drop masks that look like debris
    #    (low solidity) or degenerate shapes (low circularity).
    #    (A stale comment here previously claimed this step was removed.)
    keep_indices_morph = []
    for mask_tensor in results["masks"]:
        mask_np = np.squeeze(mask_tensor.cpu().numpy()).astype(int)

        if mask_np.ndim != 2:
            keep_indices_morph.append(False)
            continue

        props = regionprops(mask_np)
        if not props:
            keep_indices_morph.append(False)
            continue

        prop = props[0]
        perimeter = prop.perimeter
        # Circularity 4*pi*A / P^2 is 1.0 for a perfect circle, -> 0 for
        # highly irregular outlines; guard against zero perimeter.
        circularity = (4 * np.pi * prop.area) / (perimeter ** 2) if perimeter > 0 else 0

        keep_indices_morph.append(
            prop.solidity > MIN_SOLIDITY and circularity > MIN_CIRCULARITY
        )

    # NOTE(review): if EVERY mask fails the morphology test, filtering is
    # skipped entirely and all masks are kept — presumably a deliberate
    # fallback to avoid returning zero results; confirm with the author.
    if any(keep_indices_morph):
        keep_indices_tensor = torch.tensor(keep_indices_morph, device=results["masks"].device)
        before_count = len(results["masks"])
        results = _filter_results(results, keep_indices_tensor)
        print(f"[Filter] Morphology: Dropped {before_count - len(results['masks'])} debris-like objects.")

    # 6. NMS to remove duplicate detections on the same object.
    pred_boxes = results["boxes"]
    pred_scores = results["scores"]
    if len(pred_scores) > 1:
        keep_indices_nms = torchvision.ops.nms(pred_boxes, pred_scores, iou_threshold=0.3)
        results = _filter_results(results, keep_indices_nms)

    # 7. Save outputs: visualization PNG plus compressed mask arrays.
    _save_plot(raw_image, results, sam_input_boxes, text_prompt, plot_filename)

    mask_count = len(results['masks'])
    if mask_count > 0:
        masks_list = [m.cpu().numpy().squeeze() for m in results['masks']]
        masks_array = np.array(masks_list)
        np.savez_compressed(data_filename, masks=masks_array)
    else:
        # Still write the file so downstream consumers find a (empty) archive.
        np.savez_compressed(data_filename, masks=np.array([]))

    total_time = time.time() - t_start
    print(f"[Engine] ✅ Done in {total_time:.2f}s. Saved {mask_count} masks.")

    return f"SUCCESS: Found {mask_count} '{text_prompt}' objects. MASK_FILE={data_filename} PLOT_FILE={plot_filename}"


def _filter_results(results, keep_indices):
    """Helper to slice all dictionary keys at once."""
    results["masks"] = results["masks"][keep_indices]
    results["scores"] = results["scores"][keep_indices]
    results["boxes"] = results["boxes"][keep_indices]
    return results


def _save_plot(image, results, boxes, label, filename):
    """Save visualization of segmentation results with bounding boxes.

    Args:
        image: Background image (the possibly-resized PIL image).
        results: Dict with "masks" and "scores" from post-processing.
        boxes: Input prompt boxes as [x_min, y_min, x_max, y_max] pixel coords.
        label: Text prompt used for segmentation; drawn as the plot title.
        filename: Output PNG path.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)

    # 1. Draw segmentation masks as one translucent RGBA overlay.
    if len(results['scores']) > 0:
        H, W = results['masks'][0].shape[-2:]
        composite = np.zeros((H, W, 4))

        for mask, score in zip(results['masks'], results['scores']):
            if score > 0.3:
                m = mask.cpu().numpy().squeeze()
                color = np.random.random(3)
                # Blend this mask's random color into the composite.
                for c in range(3):
                    composite[:, :, c] = np.maximum(composite[:, :, c], m * color[c])
                composite[:, :, 3] = np.maximum(composite[:, :, 3], m * 0.5)

        ax.imshow(composite)

    # 2. Draw the input prompt boxes. (Fix: `boxes` was accepted but never
    #    used, despite the docstring promising bounding boxes — and
    #    matplotlib.patches is imported at file top for this.)
    for x_min, y_min, x_max, y_max in boxes:
        rect = patches.Rectangle(
            (x_min, y_min), x_max - x_min, y_max - y_min,
            linewidth=1.5, edgecolor='lime', facecolor='none'
        )
        ax.add_patch(rect)

    # 3. Title with the prompt label (was accepted but unused).
    ax.set_title(label)
    ax.axis('off')

    # Save tightly to remove whitespace.
    fig.savefig(filename, bbox_inches='tight', pad_inches=0)
    plt.close(fig)