# CAD-AID / src/utils/segmentation.py
# Julia Jørstad
# First version
# 452a352
# raw
# history blame
# 2.51 kB
import numpy as np
import cv2
from shapely import Polygon
from shapely.wkt import dumps
from shapely.geometry import box
import pandas as pd
def predict_segments(model, image_path, conf=0.5, classes=0):
    """Run the YOLOv8 segmentation model on a floorplan drawing.

    Args:
        model: Locally stored segmentation model; only its ``predict``
            method is used here.
        image_path: Path to the uploaded image.
        conf: Confidence threshold forwarded to the model.
        classes: Class filter forwarded to the model: 0 (BRA, area inside
            exterior walls) or 1 (area including exterior walls).

    Returns:
        Whatever ``model.predict`` returns, passed through unchanged.
    """
    # retina_masks is always enabled; the other options come from the caller.
    predict_options = {"conf": conf, "classes": classes, "retina_masks": True}
    return model.predict(image_path, **predict_options)
def segmentation_to_binary(results):
    """Convert segmentation results into a flat list of 0/255 uint8 masks.

    Args:
        results: Iterable of result objects. Each exposes ``masks``, which is
            either ``None`` or an object whose ``data`` attribute iterates
            over per-instance masks supporting ``.cpu().numpy()``.

    Returns:
        list: One ``np.uint8`` array per mask, in the order encountered,
        with each value cast to uint8 and multiplied by 255. Results whose
        ``masks`` is ``None`` contribute nothing.
    """
    return [
        instance_mask.cpu().numpy().astype(np.uint8) * 255
        for result in results
        if result.masks is not None
        for instance_mask in result.masks.data
    ]
def segmask_to_pandas(seg_results):
    """Store segmentation masks as polygons in a DataFrame, one row per mask.

    Args:
        seg_results: Iterable of YOLO segmentation result objects. Each
            exposes ``masks`` (``None`` or an object whose ``xy`` attribute
            lists one polygon array per instance) and ``boxes.xyxy`` (a
            tensor-like object supporting ``.cpu().numpy()``).

    Returns:
        pandas.DataFrame with columns:
            - ``mask_id``: running index over all masks in all results,
              starting at 0.
            - ``polygon``: polygon points truncated to int, as a nested list.
            - ``bboxes``: the matching box truncated to int, as an array.
        The columns are present even when no masks were found.
    """
    columns = ["mask_id", "polygon", "bboxes"]
    rows = []
    for result in seg_results:
        # masks can be None (segmentation_to_binary guards the same case;
        # presumably this is a result without detections); skip it rather
        # than fail on `None.xy`.
        if result.masks is None:
            continue
        bboxes = result.boxes.xyxy.cpu().numpy()
        # Polygons and boxes are paired positionally.
        for polygon, bbox in zip(result.masks.xy, bboxes):
            rows.append(
                {
                    "mask_id": len(rows),
                    "polygon": polygon.astype(int).tolist(),
                    "bboxes": bbox.astype(int),
                }
            )
    # Explicit columns keep the schema stable for an empty result set.
    return pd.DataFrame(rows, columns=columns)
def fill_segments(mask, bboxes):
    """Dilate a mask, cut out the given boxes, and fill enclosed holes.

    Steps:
        1. Binarise ``mask`` to 0/255 uint8 (any value > 0 is foreground).
        2. Dilate the foreground with a 3x3 kernel, 3 iterations.
        3. Zero every pixel that lies inside one of ``bboxes``.
        4. Flood-fill the background from pixel (0, 0); whatever the fill
           cannot reach is a hole and is added to the foreground.

    Args:
        mask: 2-D array; values > 0 are treated as foreground.
        bboxes: Iterable of ``(x1, y1, x2, y2)`` integer boxes, used
            directly as slice bounds ``[y1:y2, x1:x2]``.

    Returns:
        np.ndarray: uint8 array with the same shape as ``mask``, values in
        {0, 255}.
    """
    binary = np.uint8(mask > 0) * 255

    # Rasterise the boxes. The original statement was a bare subscript with
    # no assignment, so the boxes were never actually excluded.
    bbox_mask = np.zeros_like(binary)
    for x1, y1, x2, y2 in bboxes:
        bbox_mask[y1:y2, x1:x2] = 255
    outside_boxes = cv2.bitwise_not(bbox_mask)

    kernel = np.ones((3, 3), np.uint8)
    dilated = cv2.dilate(binary, kernel, iterations=3)
    final_mask = cv2.bitwise_and(dilated, outside_boxes)

    # cv2.floodFill needs a scratch mask 2 pixels larger than the image in
    # each dimension.
    flooded = final_mask.copy()
    height, width = flooded.shape
    scratch = np.zeros((height + 2, width + 2), np.uint8)
    # NOTE(review): the seed (0, 0) is assumed to be background. If the
    # foreground reaches that corner the fill changes nothing and the
    # returned mask becomes entirely 255 - confirm callers never hit this.
    cv2.floodFill(flooded, scratch, (0, 0), 255)

    # Pixels the fill did not reach are enclosed holes.
    holes = cv2.bitwise_not(flooded)
    return cv2.bitwise_or(holes, final_mask)
def find_text_in_segments(masks, ocr_results):
    """Placeholder: not implemented yet.

    Both arguments are accepted and ignored; the function always returns
    ``None``.
    """
    # TODO: implement the lookup of OCR results inside segment masks.
    return None