# CAD-AID/src/utils/segmentation.py
# Author: Julia Jørstad
# First version (452a352)
import numpy as np
import cv2
from shapely import Polygon
from shapely.wkt import dumps
from shapely.geometry import box
import pandas as pd
def predict_segments(model, image_path, conf=0.5, classes=0):
    """Run the YOLOv8 segmentation model on a floorplan drawing.

    Args:
        model: Locally stored segmentation model exposing ``predict``.
        image_path: Path to the uploaded image.
        conf: Confidence threshold forwarded to the model.
        classes: Class filter forwarded to the model; 0 = BRA (area inside
            exterior walls), 1 = area including exterior walls.

    Returns:
        The raw object returned by ``model.predict`` (unmodified).
    """
    # retina_masks=True asks Ultralytics for high-resolution masks
    # (see the Ultralytics predict-argument documentation).
    predict_kwargs = dict(conf=conf, classes=classes, retina_masks=True)
    return model.predict(image_path, **predict_kwargs)
def segmentation_to_binary(results):
    """Convert YOLO segmentation results into a flat list of binary masks.

    Results whose ``masks`` attribute is ``None`` (no detections) are skipped.

    Args:
        results: Iterable of prediction results, each with a ``masks``
            attribute that is either ``None`` or exposes ``.data``
            (an iterable of per-mask tensors).

    Returns:
        list[np.ndarray]: One uint8 array per mask, in result order. Each
        mask value is cast to uint8 and multiplied by 255, so a 0/1 mask
        becomes 0/255.
    """
    return [
        mask.cpu().numpy().astype(np.uint8) * 255
        for result in results
        if result.masks is not None
        for mask in result.masks.data
    ]
def segmask_to_pandas(seg_results):
    """Store segmentation masks as polygons in a DataFrame, one row per mask.

    Results without masks (``result.masks is None``, i.e. no detections)
    are skipped instead of raising AttributeError. ``segmentation_to_binary``
    handles that case the same way.

    Args:
        seg_results: Iterable of YOLO segmentation results. Each exposes
            ``boxes.xyxy`` (tensor of boxes) and ``masks.xy`` (sequence of
            polygon coordinate arrays), or has ``masks`` set to ``None``.

    Returns:
        pd.DataFrame with columns:
            - 'mask_id': running integer id across all results (0, 1, 2, ...)
            - 'polygon': polygon points as a list of [x, y] ints (truncated)
            - 'bboxes': the mask's box as an int numpy array (x1, y1, x2, y2)
        The columns are present even when no masks were found.
    """
    records = []
    for result in seg_results:
        if result.masks is None:
            continue
        boxes = result.boxes.xyxy.cpu().numpy()
        # Polygons and boxes are paired positionally, one box per mask.
        for polygon, bbox in zip(result.masks.xy, boxes):
            records.append({
                "mask_id": len(records),
                "polygon": polygon.astype(int).tolist(),
                "bboxes": bbox.astype(int),
            })
    return pd.DataFrame(records, columns=["mask_id", "polygon", "bboxes"])
def fill_segments(mask, bboxes, dilate_iterations=3):
    """Dilate a binary mask, carve out bounding boxes, and fill enclosed holes.

    Args:
        mask: 2-D array; any pixel > 0 is treated as foreground.
        bboxes: Iterable of (x1, y1, x2, y2) pixel boxes, where x indexes
            columns and y indexes rows. Pixels inside these boxes are removed
            from the dilated mask before hole filling. Coordinates may be
            floats; they are truncated to int and negatives clamped to 0.
        dilate_iterations: Number of 3x3 dilation passes (default 3, the
            previously hard-coded value).

    Returns:
        np.ndarray (uint8, same height/width as ``mask``): 255 for the
        dilated, box-carved mask plus every background region that is not
        reachable from the image border (i.e. enclosed holes), 0 elsewhere.
    """
    mask = np.uint8(mask > 0) * 255
    h, w = mask.shape[:2]

    bbox_mask = np.zeros_like(mask)
    for x1, y1, x2, y2 in bboxes:
        # Cast (boxes may be float) and clamp: negative slice indices would
        # wrap around to the opposite edge of the image.
        x1, y1 = max(int(x1), 0), max(int(y1), 0)
        x2, y2 = max(int(x2), 0), max(int(y2), 0)
        # Previously this line lacked the assignment, so boxes were never drawn.
        bbox_mask[y1:y2, x1:x2] = 255
    bbox_inv = cv2.bitwise_not(bbox_mask)

    kernel = np.ones((3, 3), np.uint8)
    expanded_mask = cv2.dilate(mask, kernel, iterations=dilate_iterations)
    final_mask = cv2.bitwise_and(expanded_mask, bbox_inv)

    # Pad with a 1-px background border so the flood-fill seed at (0, 0) is
    # always background and reaches every border-touching background region,
    # even when the mask itself touches the image corner or spans the image.
    padded = np.pad(final_mask, 1, mode="constant", constant_values=0)
    flood_mask = np.zeros((h + 4, w + 4), np.uint8)  # floodFill needs +2 each side
    cv2.floodFill(padded, flood_mask, (0, 0), 255)

    # After the fill, only enclosed holes remain 0; invert to select them.
    holes = cv2.bitwise_not(np.ascontiguousarray(padded[1:-1, 1:-1]))
    return cv2.bitwise_or(holes, final_mask)
def find_text_in_segments(masks, ocr_results):
    """Placeholder for matching OCR text to segmentation masks.

    Not implemented yet: the function performs no work and returns ``None``.
    The name suggests it will locate OCR detections inside each mask, but
    the intended contract is not defined here. TODO: confirm with the
    author.

    Args:
        masks: Segmentation masks (expected format not yet defined).
        ocr_results: OCR output (expected format not yet defined).
    """
    return None