"""
SAM3 segmentation execution.

Pipeline:
  1. load the image and downscale it so its longest side is <= MAX_DIMENSION;
  2. convert the request's 0-1000 normalized boxes to pixel boxes;
  3. run text + box prompted SAM3 inference;
  4. drop debris-like masks with a solidity / circularity (morphological) filter;
  5. remove duplicate detections with NMS;
  6. write a visualization PNG and an ``.npz`` of the surviving masks to OUTPUT_DIR.
"""
import os
import re
import time

import matplotlib
matplotlib.use('Agg')  # headless backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from skimage.measure import regionprops
import torch
import torchvision
import numpy as np
from PIL import Image

from ..config.schemas import ComponentRequest
from ..config.dependencies import AnalysisDeps

MAX_DIMENSION = 1024  # Speed optimization: downscale images whose longest side exceeds this
MIN_SOLIDITY = 0.50  # masks with solidity <= this are treated as debris
MIN_CIRCULARITY = 0.1  # masks with circularity (4*pi*area / perimeter^2) <= this are treated as debris
SCORE_THRESHOLD = 0.3  # score cutoff used by post-processing and when overlaying masks
NMS_IOU_THRESHOLD = 0.3  # IoU above which NMS keeps only the higher-scoring detection
NORMALIZED_COORD_MAX = 1000  # request boxes are normalized to the range 0..1000

# Use /tmp for all outputs
OUTPUT_DIR = "/tmp"


def execute_segmentation(deps: AnalysisDeps, request: ComponentRequest) -> str:
    """
    Execute SAM3 segmentation for the given component request.

    Args:
        deps: Supplies ``image_path``, ``sam_model``, ``sam_processor`` and ``device``.
        request: Supplies ``color`` and ``entity`` (joined into the text prompt) and
            ``bboxes`` whose ``xmin/ymin/xmax/ymax`` are normalized to 0-1000.

    Returns:
        A status string. On success it starts with ``"SUCCESS:"`` and embeds
        ``MASK_FILE=<path>`` and ``PLOT_FILE=<path>``. An image-load failure, an
        empty box list, and a missing model/processor are reported as strings
        rather than raised. Errors raised during inference propagate to the caller.
    """
    t_start = time.time()
    text_prompt = f"{request.color} {request.entity}"
    print(f"\n[Engine] Segmenting: '{text_prompt}' ({len(request.bboxes)} boxes).")

    # 1. Load Image (closing the file: convert() returns an independent in-memory copy)
    try:
        with Image.open(deps.image_path) as img:
            raw_image = img.convert("RGB")
    except Exception as e:
        return f"Error loading image: {e}"

    # 2. Resize image if too large (Critical for speed)
    raw_image = _downscale(raw_image)
    w, h = raw_image.size

    # 3. Convert normalized coords (0-1000) to pixel coords of the (possibly resized) image
    sam_input_boxes = [_to_pixel_box(box, w, h) for box in request.bboxes]
    if not sam_input_boxes:
        return "No valid boxes provided."

    safe_label = _safe_label(request.entity)
    plot_filename = os.path.join(OUTPUT_DIR, f"out_{safe_label}.png")
    data_filename = os.path.join(OUTPUT_DIR, f"data_{safe_label}.npz")

    if deps.sam_model is None or deps.sam_processor is None:
        return f"[Mock] Would segment '{text_prompt}'."

    # 4. Inference
    print("[Engine] Running Inference...")
    t_inf = time.time()
    inputs = deps.sam_processor(
        images=raw_image,
        text=text_prompt,
        input_boxes=[sam_input_boxes],  # outer list = batch of one image
        # label 1 for every box (positive prompt in the processor's convention — confirm)
        input_boxes_labels=[[1] * len(sam_input_boxes)],
        return_tensors="pt",
    ).to(deps.device)
    with torch.inference_mode():
        outputs = deps.sam_model(**inputs)
    results = deps.sam_processor.post_process_instance_segmentation(
        outputs,
        threshold=SCORE_THRESHOLD,
        target_sizes=inputs["original_sizes"].tolist(),
    )[0]
    print(f"[Engine] Inference took {time.time() - t_inf:.2f}s")

    # 5. Morphological filtering (solidity / circularity) to drop debris-like masks.
    keep_morph = _morphology_keep_mask(results["masks"])
    # Only filter when at least one mask survives: if every mask looks like debris,
    # keep them all rather than returning nothing.
    if keep_morph.any():
        before_count = len(results["masks"])
        results = _filter_results(results, keep_morph)
        print(f"[Filter] Morphology: Dropped {before_count - len(results['masks'])} debris-like objects.")

    # 6. NMS (Keep this to remove duplicate detections on the same object)
    if len(results["scores"]) > 1:
        keep_nms = torchvision.ops.nms(
            results["boxes"], results["scores"], iou_threshold=NMS_IOU_THRESHOLD
        )
        results = _filter_results(results, keep_nms)

    # 7. Save outputs
    _save_plot(raw_image, results, sam_input_boxes, text_prompt, plot_filename)
    mask_count = len(results['masks'])
    if mask_count > 0:
        masks_array = np.array([m.cpu().numpy().squeeze() for m in results['masks']])
    else:
        masks_array = np.array([])
    np.savez_compressed(data_filename, masks=masks_array)

    total_time = time.time() - t_start
    print(f"[Engine] ✅ Done in {total_time:.2f}s. Saved {mask_count} masks.")
    return f"SUCCESS: Found {mask_count} '{text_prompt}' objects. MASK_FILE={data_filename} PLOT_FILE={plot_filename}"


def _downscale(image):
    """Return *image* resized (LANCZOS, aspect preserved) so its longest side is <= MAX_DIMENSION."""
    w, h = image.size
    if max(w, h) <= MAX_DIMENSION:
        return image
    scale_factor = MAX_DIMENSION / max(w, h)
    # max(1, ...) guards against a zero-size side for extremely thin images.
    new_w = max(1, int(w * scale_factor))
    new_h = max(1, int(h * scale_factor))
    resized = image.resize((new_w, new_h), Image.Resampling.LANCZOS)
    print(f"[Engine] ⚡ Resized image from {w}x{h} to {new_w}x{new_h} (Speedup)")
    return resized


def _to_pixel_box(box, width, height):
    """
    Convert a 0-1000 normalized box to pixel ``[x_min, y_min, x_max, y_max]``.

    Coordinates are clamped to the image and each axis is ordered (min <= max),
    so a slightly out-of-range or swapped box from the caller still yields a
    valid box prompt.
    """
    def _scale(value, size):
        return min(max(value / NORMALIZED_COORD_MAX, 0.0), 1.0) * size

    x_min, x_max = sorted((_scale(box.xmin, width), _scale(box.xmax, width)))
    y_min, y_max = sorted((_scale(box.ymin, height), _scale(box.ymax, height)))
    return [x_min, y_min, x_max, y_max]


def _safe_label(entity) -> str:
    """Turn an entity name into a filename-safe token (lower-case; path separators and dots become '_')."""
    label = f"{entity}".replace(" ", "_").lower()
    # Anything other than word characters or '-' (e.g. '/', '.') could escape OUTPUT_DIR.
    label = re.sub(r"[^\w\-]", "_", label)
    return label or "object"


def _is_solid_object(mask) -> bool:
    """Return True if a single mask passes the MIN_SOLIDITY and MIN_CIRCULARITY thresholds."""
    label_image = np.squeeze(mask).astype(int)
    if label_image.ndim != 2:
        return False
    props = regionprops(label_image)
    if not props:
        return False  # empty mask: no foreground region
    prop = props[0]
    perimeter = prop.perimeter
    circularity = (4 * np.pi * prop.area) / (perimeter ** 2) if perimeter > 0 else 0
    return bool(prop.solidity > MIN_SOLIDITY and circularity > MIN_CIRCULARITY)


def _morphology_keep_mask(masks):
    """Return a bool tensor (on ``masks.device``) marking which masks look like solid objects."""
    keep = [_is_solid_object(mask.cpu().numpy()) for mask in masks]
    return torch.tensor(keep, dtype=torch.bool, device=masks.device)


def _filter_results(results, keep_indices):
    """Slice the ``masks``, ``scores`` and ``boxes`` entries of *results* in place and return it."""
    results["masks"] = results["masks"][keep_indices]
    results["scores"] = results["scores"][keep_indices]
    results["boxes"] = results["boxes"][keep_indices]
    return results


def _save_plot(image, results, boxes, label, filename):
    """
    Save visualization of segmentation results with bounding boxes.

    Masks scoring above SCORE_THRESHOLD are overlaid in random colours at 50%
    alpha; each prompt box (pixel ``[x_min, y_min, x_max, y_max]``) is drawn as
    an outline. ``label`` is currently unused and kept for interface compatibility.
    The figure is always closed, even if saving fails.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(image)

        # 1. Draw Segmentation Masks
        if len(results['scores']) > 0:
            H, W = results['masks'][0].shape[-2:]
            composite = np.zeros((H, W, 4))
            for mask, score in zip(results['masks'], results['scores']):
                if score > SCORE_THRESHOLD:
                    m = mask.cpu().numpy().squeeze()
                    color = np.random.random(3)
                    composite[:, :, :3] = np.maximum(composite[:, :, :3], m[..., None] * color)
                    composite[:, :, 3] = np.maximum(composite[:, :, 3], m * 0.5)
            ax.imshow(composite)

        # 2. Draw the prompt bounding boxes
        for x_min, y_min, x_max, y_max in boxes:
            ax.add_patch(patches.Rectangle(
                (x_min, y_min), x_max - x_min, y_max - y_min,
                fill=False, edgecolor="lime", linewidth=2,
            ))

        ax.axis('off')
        # Save tightly to remove whitespace
        fig.savefig(filename, bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)