# Steph254's picture
# Upload 26 files
# b9e0048 verified
import cv2
import numpy as np
from ultralytics import YOLO
from segment_anything import SamPredictor, sam_model_registry
from app.utils.preprocess import preprocess_image
from app.utils.feature_extraction import extract_piece_features
class PuzzlePieceSegmenter:
    """Detect puzzle pieces with YOLOv8 and refine each detection with SAM.

    YOLO proposes coarse bounding boxes; each box is used as a prompt for a
    SAM predictor, and the external contours of the resulting masks are
    filtered by area before features are extracted for each piece.
    """

    def __init__(
        self,
        yolo_weights: str = "yolov8n.pt",
        sam_model_type: str = "vit_b",
        sam_checkpoint: str = "mobile_sam.pth",
        min_piece_area: int = 500,
        max_piece_area: int = 50000,
    ):
        """Load the detection and segmentation models.

        Args:
            yolo_weights: Path or name of the YOLO weights to load.
            sam_model_type: Key into ``sam_model_registry``.
            sam_checkpoint: Path to the SAM checkpoint file.
            min_piece_area: Smallest contour area (inclusive) kept as a piece.
            max_piece_area: Largest contour area (inclusive) kept as a piece.
        """
        # Lightweight YOLOv8 model for coarse detection.
        self.yolo = YOLO(yolo_weights)
        # NOTE(review): the default checkpoint name suggests MobileSAM weights
        # while the default registry key is "vit_b" -- confirm they match.
        self.sam = sam_model_registry[sam_model_type](checkpoint=sam_checkpoint)
        self.predictor = SamPredictor(self.sam)
        self.min_piece_area = min_piece_area
        self.max_piece_area = max_piece_area

    def segment_pieces(self, image: np.ndarray) -> list:
        """Segment ``image`` into individual puzzle pieces.

        Args:
            image: Input image; it is passed through ``preprocess_image``
                before detection and segmentation.

        Returns:
            A list of dicts, one per accepted piece, with keys ``'id'``
            (consecutive integers starting at 0), ``'image'`` (JPEG-encoded
            bytes of the piece crop) and ``'features'`` (the value returned
            by ``extract_piece_features``).
        """
        processed = preprocess_image(image)

        # YOLOv8 for coarse detection: one XYXY box per detected object.
        results = self.yolo(processed)
        boxes = results[0].boxes.xyxy.cpu().numpy()
        if len(boxes) == 0:
            # Nothing detected: skip the SAM image embedding entirely.
            return []

        # SAM for fine segmentation.
        # NOTE(review): SamPredictor.set_image assumes RGB input by default --
        # confirm preprocess_image returns that channel order.
        self.predictor.set_image(processed)

        pieces = []
        piece_id = 0
        for box in boxes:
            # The box prompt keyword is ``box``; any other name raises TypeError.
            masks, _, _ = self.predictor.predict(box=box, multimask_output=False)
            for mask in masks:
                contours, _ = cv2.findContours(
                    mask.astype(np.uint8),
                    cv2.RETR_EXTERNAL,
                    cv2.CHAIN_APPROX_SIMPLE,
                )
                for contour in contours:
                    area = cv2.contourArea(contour)
                    if not self.min_piece_area <= area <= self.max_piece_area:
                        continue

                    # Extract features and image for this piece.
                    piece_image, features = extract_piece_features(processed, contour)
                    if not features:
                        continue

                    encoded_ok, jpeg = cv2.imencode('.jpg', piece_image)
                    if not encoded_ok:
                        # Encoding failed; do not emit a piece without an image.
                        continue

                    pieces.append({
                        'id': piece_id,
                        'image': jpeg.tobytes(),
                        'features': features,
                    })
                    piece_id += 1
        return pieces