# NOTE(review): the three lines here ("Spaces:" / "Sleeping" / "Sleeping") were
# page chrome from the hosting site captured by the scrape, not source code.
"""
SAM3 segmentation execution - Optimized & Simplified.
NOTE: despite earlier notes claiming morphological filtering was removed,
solidity/circularity filtering is still applied below, before NMS.
"""
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as patches | |
| from skimage.measure import regionprops | |
| import torch | |
| import torchvision | |
| import numpy as np | |
| import time | |
| from PIL import Image | |
| from ..config.schemas import ComponentRequest | |
| from ..config.dependencies import AnalysisDeps | |
| MAX_DIMENSION = 1024 # Speed optimization: Downscale large images | |
| MIN_SOLIDITY = 0.50 | |
| MIN_CIRCULARITY = 0.1 | |
| # Use /tmp for all outputs | |
| OUTPUT_DIR = "/tmp" | |
def execute_segmentation(deps: AnalysisDeps, request: ComponentRequest) -> str:
    """
    Execute SAM3 segmentation for the given component request.

    Pipeline: load image -> optional downscale -> map normalized boxes to
    pixel coords -> SAM inference -> morphology filter -> NMS -> save
    plot (PNG) and masks (NPZ) under /tmp.

    Args:
        deps: Provides image_path, sam_model, sam_processor and device.
        request: Provides entity, color and bboxes (coords normalized to 0-1000).

    Returns:
        A status string. On success it embeds MASK_FILE=... and PLOT_FILE=...
        paths; on image-load failure or empty box list, an error/notice string.
    """
    t_start = time.time()
    # Prompt is "<color> <entity>", e.g. "red capacitor".
    text_prompt = f"{request.color} {request.entity}"
    print(f"\n[Engine] Segmenting: '{text_prompt}' ({len(request.bboxes)} boxes).")
    # 1. Load Image (errors are reported as a string, not raised)
    try:
        raw_image = Image.open(deps.image_path).convert("RGB")
    except Exception as e:
        return f"Error loading image: {e}"
    # 2. Resize image if too large (Critical for speed).
    # Boxes are mapped AFTER this, so they are scaled consistently.
    w, h = raw_image.size
    scale_factor = 1.0
    if max(w, h) > MAX_DIMENSION:
        scale_factor = MAX_DIMENSION / max(w, h)
        new_w = int(w * scale_factor)
        new_h = int(h * scale_factor)
        raw_image = raw_image.resize((new_w, new_h), Image.Resampling.LANCZOS)
        print(f"[Engine] ⚡ Resized image from {w}x{h} to {new_w}x{new_h} (Speedup)")
        w, h = new_w, new_h
    # 3. Convert normalized coords (0-1000) to pixel coords of the (possibly
    # resized) image, in SAM's expected [x_min, y_min, x_max, y_max] order.
    sam_input_boxes = []
    for box in request.bboxes:
        y_min = (box.ymin / 1000) * h
        x_min = (box.xmin / 1000) * w
        y_max = (box.ymax / 1000) * h
        x_max = (box.xmax / 1000) * w
        sam_input_boxes.append([x_min, y_min, x_max, y_max])
    if not sam_input_boxes:
        return "No valid boxes provided."
    # Output paths keyed by entity name.
    # NOTE(review): OUTPUT_DIR is defined at module level but these paths
    # hardcode /tmp instead — confirm which should win.
    safe_label = f"{request.entity}".replace(" ", "_").lower()
    plot_filename = f"/tmp/out_{safe_label}.png"
    data_filename = f"/tmp/data_{safe_label}.npz"
    # Mock path for environments without the model loaded.
    if deps.sam_model is None or deps.sam_processor is None:
        return f"[Mock] Would segment '{text_prompt}'."
    # 4. Inference (single-image batch: boxes/labels wrapped in an outer list)
    print("[Engine] Running Inference...")
    t_inf = time.time()
    sam_input_labels = [[1] * len(sam_input_boxes)]  # 1 = positive box prompt
    input_boxes_batch = [sam_input_boxes]
    inputs = deps.sam_processor(
        images=raw_image,
        text=text_prompt,
        input_boxes=input_boxes_batch,
        input_boxes_labels=sam_input_labels,
        return_tensors="pt"
    ).to(deps.device)
    with torch.inference_mode():
        outputs = deps.sam_model(**inputs)
    # [0]: unwrap the single image from the batch. Presumably results holds
    # "masks", "scores", "boxes" tensors — verify against processor docs.
    results = deps.sam_processor.post_process_instance_segmentation(
        outputs,
        threshold=0.3,
        target_sizes=inputs["original_sizes"].tolist()
    )[0]
    print(f"[Engine] Inference took {time.time() - t_inf:.2f}s")
    # 5. Morphological filtering (solidity/circularity) to drop debris-like
    # detections. NOTE(review): an earlier comment claimed this step was
    # "REMOVED", but it is clearly still executed here.
    keep_indices_morph = []
    for i, mask_tensor in enumerate(results["masks"]):
        mask_np = mask_tensor.cpu().numpy()
        mask_np = np.squeeze(mask_np).astype(int)
        if mask_np.ndim != 2:
            # Unexpected mask shape — drop it.
            keep_indices_morph.append(False)
            continue
        props = regionprops(mask_np)
        if not props:
            # Empty mask (no foreground region) — drop it.
            keep_indices_morph.append(False)
            continue
        prop = props[0]
        perimeter = prop.perimeter
        # circularity = 4*pi*A / P^2 (1.0 for a perfect disk)
        circularity = (4 * np.pi * prop.area) / (perimeter ** 2) if perimeter > 0 else 0
        is_solid = prop.solidity > MIN_SOLIDITY
        is_round_enough = circularity > MIN_CIRCULARITY
        keep_indices_morph.append(is_solid and is_round_enough)
    # NOTE(review): if EVERY mask fails morphology, the filter is skipped and
    # all masks are kept — presumably a recall fallback; confirm intent.
    if any(keep_indices_morph):
        keep_indices_tensor = torch.tensor(keep_indices_morph, device=results["masks"].device)
        before_count = len(results["masks"])
        results = _filter_results(results, keep_indices_tensor)
        print(f"[Filter] Morphology: Dropped {before_count - len(results['masks'])} debris-like objects.")
    # 6. NMS (Keep this to remove duplicate detections on the same object)
    pred_boxes = results["boxes"]
    pred_scores = results["scores"]
    if len(pred_scores) > 1:
        keep_indices_nms = torchvision.ops.nms(pred_boxes, pred_scores, iou_threshold=0.3)
        results = _filter_results(results, keep_indices_nms)
    # 7. Save outputs: visualization PNG + compressed mask arrays.
    _save_plot(raw_image, results, sam_input_boxes, text_prompt, plot_filename)
    mask_count = len(results['masks'])
    if mask_count > 0:
        masks_list = [m.cpu().numpy().squeeze() for m in results['masks']]
        masks_array = np.array(masks_list)
        np.savez_compressed(data_filename, masks=masks_array)
    else:
        # Write an empty archive so downstream consumers can always load it.
        np.savez_compressed(data_filename, masks=np.array([]))
    total_time = time.time() - t_start
    print(f"[Engine] ✅ Done in {total_time:.2f}s. Saved {mask_count} masks.")
    return f"SUCCESS: Found {mask_count} '{text_prompt}' objects. MASK_FILE={data_filename} PLOT_FILE={plot_filename}"
| def _filter_results(results, keep_indices): | |
| """Helper to slice all dictionary keys at once.""" | |
| results["masks"] = results["masks"][keep_indices] | |
| results["scores"] = results["scores"][keep_indices] | |
| results["boxes"] = results["boxes"][keep_indices] | |
| return results | |
def _save_plot(image, results, boxes, label, filename):
    """Save visualization of segmentation results with bounding boxes.

    Fixes: `boxes` and `label` were accepted but never used, even though the
    docstring promised bounding boxes and `matplotlib.patches` was imported
    for exactly that purpose. The input boxes are now drawn and the prompt is
    shown as the plot title.

    Args:
        image: PIL image (already resized) used as the backdrop.
        results: Dict with 'masks' and 'scores' (torch tensors) from
            post-processing; masks are assumed to match the image size.
        boxes: Input prompt boxes in pixel coords, each [x_min, y_min, x_max, y_max].
        label: Text prompt used for segmentation; rendered as the title.
        filename: Output PNG path.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    # 1. Draw segmentation masks as one RGBA overlay (per-channel max so
    # overlapping masks don't wash each other out).
    if len(results['scores']) > 0:
        H, W = results['masks'][0].shape[-2:]
        composite = np.zeros((H, W, 4))
        for mask, score in zip(results['masks'], results['scores']):
            if score > 0.3:
                m = mask.cpu().numpy().squeeze()
                color = np.random.random(3)  # random color per instance
                for c in range(3):
                    composite[:, :, c] = np.maximum(composite[:, :, c], m * color[c])
                composite[:, :, 3] = np.maximum(composite[:, :, 3], m * 0.5)  # 50% alpha
        ax.imshow(composite)
    # 2. Draw the input prompt boxes (dashed, no fill) so the user can see
    # what region each detection was prompted from.
    for x_min, y_min, x_max, y_max in boxes:
        rect = patches.Rectangle(
            (x_min, y_min), x_max - x_min, y_max - y_min,
            linewidth=1.5, edgecolor='lime', facecolor='none', linestyle='--'
        )
        ax.add_patch(rect)
    ax.set_title(label)
    ax.axis('off')
    # Save tightly to remove whitespace
    fig.savefig(filename, bbox_inches='tight', pad_inches=0)
    plt.close(fig)