Spaces:
Sleeping
Sleeping
File size: 6,378 Bytes
201a9d0 869a58c 201a9d0 ba2f58e dafd3a4 201a9d0 4792708 201a9d0 869a58c 201a9d0 0b858e3 4792708 201a9d0 4792708 869a58c 8560169 201a9d0 4792708 201a9d0 869a58c 4792708 201a9d0 4792708 201a9d0 2feca28 201a9d0 4792708 201a9d0 4792708 201a9d0 4792708 201a9d0 4792708 201a9d0 869a58c 9402dd3 4792708 869a58c 201a9d0 869a58c 201a9d0 4792708 201a9d0 ba2f58e 201a9d0 4792708 ba2f58e 4792708 201a9d0 ba2f58e 4792708 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
"""
SAM3 segmentation execution - Optimized & Simplified.
Removed morphological filtering for maximum recall and speed.
"""
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from skimage.measure import regionprops
import torch
import torchvision
import numpy as np
import time
from PIL import Image
from ..config.schemas import ComponentRequest
from ..config.dependencies import AnalysisDeps
# Speed optimization: images whose longest side exceeds this are downscaled
# before inference (see step 2 of execute_segmentation).
MAX_DIMENSION = 1024
# Morphology-filter thresholds: a mask is kept only if its solidity and
# circularity both exceed these values (debris rejection).
MIN_SOLIDITY = 0.50
MIN_CIRCULARITY = 0.1
# Use /tmp for all outputs
OUTPUT_DIR = "/tmp"
def execute_segmentation(deps: AnalysisDeps, request: ComponentRequest) -> str:
    """
    Execute SAM3 segmentation for the given component request.

    Pipeline: load image -> downscale if larger than MAX_DIMENSION ->
    convert normalized bboxes to pixel coords -> prompted SAM inference ->
    morphology filter (solidity/circularity) -> NMS -> save a plot PNG and
    a compressed NPZ of masks under /tmp.

    Args:
        deps: provides image_path, device, and the sam_model/sam_processor
            pair (either being None triggers mock mode).
        request: provides color, entity, and bboxes whose xmin/ymin/xmax/ymax
            are assumed normalized to the 0-1000 range -- TODO confirm.

    Returns:
        A status string. On success it embeds the MASK_FILE= and PLOT_FILE=
        output paths; otherwise a human-readable error message.
    """
    t_start = time.time()
    text_prompt = f"{request.color} {request.entity}"
    print(f"\n[Engine] Segmenting: '{text_prompt}' ({len(request.bboxes)} boxes).")

    # 1. Load Image
    try:
        raw_image = Image.open(deps.image_path).convert("RGB")
    except Exception as e:
        return f"Error loading image: {e}"

    # 2. Resize image if too large (Critical for speed)
    w, h = raw_image.size
    scale_factor = 1.0
    if max(w, h) > MAX_DIMENSION:
        scale_factor = MAX_DIMENSION / max(w, h)
        new_w = int(w * scale_factor)
        new_h = int(h * scale_factor)
        raw_image = raw_image.resize((new_w, new_h), Image.Resampling.LANCZOS)
        print(f"[Engine] ⚡ Resized image from {w}x{h} to {new_w}x{new_h} (Speedup)")
        w, h = new_w, new_h

    # 3. Convert normalized coords (0-1000) to pixel coords of the
    # (possibly resized) image.
    sam_input_boxes = []
    for box in request.bboxes:
        y_min = (box.ymin / 1000) * h
        x_min = (box.xmin / 1000) * w
        y_max = (box.ymax / 1000) * h
        x_max = (box.xmax / 1000) * w
        sam_input_boxes.append([x_min, y_min, x_max, y_max])
    if not sam_input_boxes:
        return "No valid boxes provided."

    # NOTE(review): output paths are keyed on the entity only, so two
    # requests for the same entity with different colors overwrite each
    # other's files -- confirm this is acceptable.
    safe_label = f"{request.entity}".replace(" ", "_").lower()
    plot_filename = f"/tmp/out_{safe_label}.png"
    data_filename = f"/tmp/data_{safe_label}.npz"

    # Mock mode: no model/processor loaded (e.g. tests or CPU-only envs).
    if deps.sam_model is None or deps.sam_processor is None:
        return f"[Mock] Would segment '{text_prompt}'."

    # 4. Inference
    print("[Engine] Running Inference...")
    t_inf = time.time()
    sam_input_labels = [[1] * len(sam_input_boxes)]  # 1 == positive box prompt
    input_boxes_batch = [sam_input_boxes]  # batch of a single image
    inputs = deps.sam_processor(
        images=raw_image,
        text=text_prompt,
        input_boxes=input_boxes_batch,
        input_boxes_labels=sam_input_labels,
        return_tensors="pt"
    ).to(deps.device)
    with torch.inference_mode():
        outputs = deps.sam_model(**inputs)
    results = deps.sam_processor.post_process_instance_segmentation(
        outputs,
        threshold=0.3,
        target_sizes=inputs["original_sizes"].tolist()
    )[0]
    print(f"[Engine] Inference took {time.time() - t_inf:.2f}s")

    # 5. Morphological filtering (solidity/circularity).
    # NOTE(review): the module header says this filtering was removed, but
    # it is still active here -- confirm which is intended.
    keep_indices_morph = []
    for i, mask_tensor in enumerate(results["masks"]):
        mask_np = mask_tensor.cpu().numpy()
        mask_np = np.squeeze(mask_np).astype(int)
        if mask_np.ndim != 2:
            # Unexpected rank after squeeze: drop rather than crash.
            keep_indices_morph.append(False)
            continue
        # regionprops on a 0/1 label image yields at most one region
        # covering all foreground pixels (connected or not).
        props = regionprops(mask_np)
        if not props:
            # Empty mask: nothing to measure, drop it.
            keep_indices_morph.append(False)
            continue
        prop = props[0]
        perimeter = prop.perimeter
        # 4*pi*A/P^2 is 1.0 for a perfect circle, approaches 0 for
        # ragged/elongated shapes; guard against zero perimeter.
        circularity = (4 * np.pi * prop.area) / (perimeter ** 2) if perimeter > 0 else 0
        is_solid = prop.solidity > MIN_SOLIDITY
        is_round_enough = circularity > MIN_CIRCULARITY
        keep_indices_morph.append(is_solid and is_round_enough)
    # NOTE(review): if *every* mask fails the morphology test, filtering is
    # skipped and all masks are kept -- presumably a deliberate fallback to
    # avoid returning nothing; confirm.
    if any(keep_indices_morph):
        keep_indices_tensor = torch.tensor(keep_indices_morph, device=results["masks"].device)
        before_count = len(results["masks"])
        results = _filter_results(results, keep_indices_tensor)
        print(f"[Filter] Morphology: Dropped {before_count - len(results['masks'])} debris-like objects.")

    # 6. NMS (Keep this to remove duplicate detections on the same object)
    pred_boxes = results["boxes"]
    pred_scores = results["scores"]
    if len(pred_scores) > 1:
        keep_indices_nms = torchvision.ops.nms(pred_boxes, pred_scores, iou_threshold=0.3)
        results = _filter_results(results, keep_indices_nms)

    # 7. Save outputs: visualization PNG plus compressed NPZ of raw masks.
    _save_plot(raw_image, results, sam_input_boxes, text_prompt, plot_filename)
    mask_count = len(results['masks'])
    if mask_count > 0:
        masks_list = [m.cpu().numpy().squeeze() for m in results['masks']]
        masks_array = np.array(masks_list)
        np.savez_compressed(data_filename, masks=masks_array)
    else:
        # Still write an (empty) NPZ so downstream consumers find a file.
        np.savez_compressed(data_filename, masks=np.array([]))

    total_time = time.time() - t_start
    print(f"[Engine] ✅ Done in {total_time:.2f}s. Saved {mask_count} masks.")
    return f"SUCCESS: Found {mask_count} '{text_prompt}' objects. MASK_FILE={data_filename} PLOT_FILE={plot_filename}"
def _filter_results(results, keep_indices):
"""Helper to slice all dictionary keys at once."""
results["masks"] = results["masks"][keep_indices]
results["scores"] = results["scores"][keep_indices]
results["boxes"] = results["boxes"][keep_indices]
return results
def _save_plot(image, results, boxes, label, filename):
    """Save a visualization of segmentation results with bounding boxes.

    Overlays semi-transparent, randomly colored masks (score > 0.3) on the
    image, draws the input prompt boxes, titles the plot with the text
    prompt, and writes the figure to *filename* with whitespace trimmed.

    Args:
        image: PIL image (or array) accepted by matplotlib's imshow.
        results: dict with 'masks' (torch tensors, HxW after squeeze) and
            'scores' -- assumed aligned one-to-one; TODO confirm.
        boxes: list of [x_min, y_min, x_max, y_max] pixel-space boxes.
        label: text prompt used as the plot title.
        filename: output PNG path.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    # 1. Draw segmentation masks as a single RGBA composite overlay.
    if len(results['scores']) > 0:
        H, W = results['masks'][0].shape[-2:]
        composite = np.zeros((H, W, 4))
        for mask, score in zip(results['masks'], results['scores']):
            if score > 0.3:
                m = mask.cpu().numpy().squeeze()
                color = np.random.random(3)
                # Merge this mask's color into the composite; max keeps the
                # brighter channel where masks overlap.
                for c in range(3):
                    composite[:, :, c] = np.maximum(composite[:, :, c], m * color[c])
                composite[:, :, 3] = np.maximum(composite[:, :, 3], m * 0.5)
        ax.imshow(composite)
    # 2. Draw the input prompt boxes (docstring promised them, but the
    # original implementation never used the `boxes` parameter).
    for x_min, y_min, x_max, y_max in boxes:
        rect = patches.Rectangle(
            (x_min, y_min), x_max - x_min, y_max - y_min,
            linewidth=2, edgecolor='lime', facecolor='none'
        )
        ax.add_patch(rect)
    ax.set_title(label)
    ax.axis('off')
    # Save tightly to remove whitespace
    fig.savefig(filename, bbox_inches='tight', pad_inches=0)
    plt.close(fig)