Spaces:
Sleeping
Sleeping
| import os | |
| import uuid | |
| import cv2 | |
| from models.utils import save_mask_as_png | |
| from database.crud import create_session | |
def resize_if_needed(image, max_side):
    """
    Resize image if any side > max_side, keeping aspect ratio.

    :param image: np.ndarray, loaded BGR image
    :param max_side: int, max allowed size for width or height (must be > 0)
    :return: the original image object when it already fits, otherwise a
        downscaled copy whose longest side is at most max_side
    :raises ValueError: if max_side is not positive
    """
    if max_side <= 0:
        raise ValueError(f"max_side must be positive, got {max_side}")

    h, w = image.shape[:2]
    longest = max(h, w)
    if longest <= max_side:
        print("[DEBUG] The image has fine size")
        return image  # already fine

    scale = max_side / longest
    # Truncation can yield 0 for the short side of a very elongated image
    # (e.g. 1x4000 scaled by 0.384); cv2.resize rejects a zero-sized target,
    # so clamp each dimension to at least one pixel.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))

    # cv2.resize takes the target size as (width, height).
    resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
    print("[DEBUG] The image was resized")
    return resized
def process_session_image(session_id, image_path, prompt_text, sam_wrapper, dino_wrapper, save_root="outputs"):
    """
    Full pipeline: detect + segment + save + write to DB (session-based).

    :param session_id: ID of the session (string)
    :param image_path: Path to uploaded image (string)
    :param prompt_text: Prompt from user (string)
    :param sam_wrapper: Initialized SAM wrapper
    :param dino_wrapper: Initialized DINO wrapper
    :param save_root: Base output directory (default: "outputs")
    :return: List of saved PNG file paths (relative to the current working
        directory, using forward slashes)
    :raises ValueError: if the image cannot be loaded from image_path
    """
    image = cv2.imread(image_path)
    # The load result must be validated BEFORE resizing: a failed read yields
    # None, and resize_if_needed would crash on `image.shape` instead of
    # raising the descriptive error below.
    if image is None:
        raise ValueError(f"Failed to load image from path: {image_path}")
    image = resize_if_needed(image, max_side=1536)

    # 1. Run DINO detection
    boxes = dino_wrapper.detect(image, prompt_text)

    # 2. Create output folder for this session
    session_dir = os.path.join(save_root, session_id)
    os.makedirs(session_dir, exist_ok=True)

    # The prompt is user-supplied text that ends up in a filename. Replace
    # every character other than letters, digits, '.', '_' and '-' with '_'
    # so path separators and other unsafe characters cannot break or redirect
    # the output path. Spaces still map to '_' as before. Computed once, since
    # it does not depend on the loop.
    safe_prompt = "".join(
        ch if ch.isalnum() or ch in "._-" else "_" for ch in prompt_text
    )

    saved_paths = []

    # 3. Run SAM on each box
    for i, box in enumerate(boxes):
        mask = sam_wrapper.predict_with_box(image, box)
        if mask is None:
            continue  # no mask for this box; skip it

        # A random prefix keeps names unique across repeated runs in a session.
        filename = f"{uuid.uuid4().hex[:8]}_{i}_{safe_prompt}.png"
        full_path = os.path.join(session_dir, filename)
        save_mask_as_png(image, mask, full_path)

        # Store a cwd-relative path with forward slashes regardless of OS.
        relative_path = os.path.relpath(full_path, start=".").replace("\\", "/")
        saved_paths.append(relative_path)

    # 4. Save session in database
    create_session(session_id=session_id, image_path=image_path, result_paths=saved_paths)
    return saved_paths