# img_comparer/src/doclayout.py
# Author: Vivek Vaddina — "New UI and refactor code" (commit df3522f, unverified)
import numpy as np
from PIL import Image
from pathlib import Path
import supervision as sv
# from ultralytics import YOLO
from doclayout_yolo import YOLOv10
from src.config import YOLO_MODEL_PATH, log
from src.utils import binarize
# Load the layout-detection model once at import time so all callers share it.
log.debug("loading YOLO Model")
LAYOUT_MODEL = YOLOv10(YOLO_MODEL_PATH)
# Layout class names, indexed by the model's class id.
CLASSES = [
    "Caption",
    "Footnote",
    "Formula",
    "List-item",
    "Page-footer",
    "Page-header",
    "Picture",
    "Section-header",
    "Table",
    "Text",
    "Title",
    "Unknown",
]
# Define a custom color palette for each class (parallel to CLASSES).
CLASS_COLORS = [
    sv.Color(255, 0, 0),  # Red for "Caption"
    sv.Color(0, 255, 0),  # Green for "Footnote"
    sv.Color(0, 0, 255),  # Blue for "Formula"
    sv.Color(255, 255, 0),  # Yellow for "List-item"
    sv.Color(255, 0, 255),  # Magenta for "Page-footer"
    sv.Color(0, 255, 255),  # Cyan for "Page-header"
    sv.Color(128, 0, 128),  # Purple for "Picture"
    sv.Color(128, 128, 0),  # Olive for "Section-header"
    sv.Color(128, 128, 128),  # Gray for "Table"
    sv.Color(0, 128, 128),  # Teal for "Text"
    sv.Color(128, 0, 0),  # Maroon for "Title"
    sv.Color(255, 255, 255),  # White for "Unknown"
]
# Initialize the BoxAnnotator with the custom color palette and increased thickness
box_annotator = sv.BoxAnnotator(
    color=sv.ColorPalette(CLASS_COLORS),
    thickness=2,  # Increased thickness for bounding boxes
)
# Initialize the LabelAnnotator with custom background and text colors
label_annotator = sv.LabelAnnotator(
    color=sv.ColorPalette(CLASS_COLORS),  # Background colors matching bounding boxes
    text_color=sv.Color(255, 255, 255),  # White text for better readability
)
def detect(img, conf=0.2, iou=0.8, labels=False, plot=False, model=LAYOUT_MODEL):
    """Run layout detection on `img` and return the detections.

    Args:
        img: Input image (PIL image or numpy array accepted by the model).
        conf: Confidence threshold for detections.
        iou: IoU threshold for non-max suppression.
        labels: When plotting, also draw class-name labels on the boxes.
        plot: If True, display the annotated image via supervision.
        model: Layout model to run (defaults to the module-level YOLOv10).

    Returns:
        sv.Detections holding xyxy boxes, class ids, confidences and names.
    """
    # Object detection on image
    results = model(img, conf=conf, iou=iou, verbose=False)[0]
    # Convert results to detections
    detections = sv.Detections.from_ultralytics(results)
    # Annotate only when the caller asked for a plot: the annotated image is
    # never returned, so annotating unconditionally was wasted work and could
    # mutate the caller's `img` in place as a side effect.
    if plot:
        annotated_image = box_annotator.annotate(scene=img, detections=detections)
        if labels:
            # Annotate the image with labels
            annotated_image = label_annotator.annotate(
                scene=annotated_image, detections=detections
            )
        sv.plot_image(annotated_image)
    return detections
def sort_coords(xyxy):
    """Return the rows of `xyxy` ordered by ascending left edge (x1)."""
    left_edge_order = xyxy[:, 0].argsort()
    return xyxy[left_edge_order]
def get_labels_with_confidence(detections):
    """Pair each detection's class name with its confidence score.

    Returns a list of (class_name, confidence) tuples.
    """
    names = detections.data["class_name"].tolist()
    scores = detections.confidence.tolist()
    return list(zip(names, scores))
def rectangle_area(*coords):
    """Area of an axis-aligned rectangle given as four values (x1, y1, x2, y2).

    Corner order does not matter; absolute differences are used.
    """
    (x1, y1), (x2, y2) = coords[:2], coords[2:]
    return abs(x2 - x1) * abs(y2 - y1)
def get_mask(img_h, img_w, detections, class_colors=CLASS_COLORS, inc=50):
    """Rasterize detections into an RGB mask image.

    Larger boxes are painted first so that smaller (possibly nested) boxes
    remain visible on top of them.

    Args:
        img_h: Height of the source image in pixels.
        img_w: Width of the source image in pixels.
        detections: sv.Detections with `.xyxy` boxes and `.class_id`.
        class_colors: sv.Color palette indexed by class id.
        inc: Unused; kept for backward compatibility with existing callers.

    Returns:
        PIL image of shape (img_h, img_w) with class-colored rectangles.
    """
    # Sort by box area, largest first (areas from the un-rounded coords).
    areas = np.abs(
        (detections.xyxy[:, 2] - detections.xyxy[:, 0])
        * (detections.xyxy[:, 3] - detections.xyxy[:, 1])
    )
    order = np.argsort(areas)[::-1]
    mask_arr = np.zeros((img_h, img_w, 3), dtype=np.uint8)
    coords = np.round(detections.xyxy).astype(np.int32)
    for (x1, y1, x2, y2), class_id in zip(coords[order], detections.class_id[order]):
        # Fill rows y1..y2 and columns x1..x2 inclusive, directly in image
        # coordinates. The previous implementation flipped the y axis twice
        # (subtract from height, then reverse the array), which shifted every
        # box up by one pixel.
        mask_arr[y1 : y2 + 1, x1 : x2 + 1, :] = class_colors[class_id].as_rgb()
    return Image.fromarray(mask_arr)
def merge_masks(masks, weights=None):
    """Blend several masks into a single image via a weighted sum.

    Args:
        masks: Sequence of at least two equally-shaped PIL images (or arrays).
        weights: Optional per-mask weights; defaults to a uniform average.

    Returns:
        PIL image of the weighted combination, cast to uint8.

    Raises:
        ValueError: If fewer than two masks are supplied.
    """
    # The original check only rejected exactly one mask; an empty sequence
    # fell through and crashed with IndexError at `masks[0]` below.
    if len(masks) < 2:
        raise ValueError("more than 1 mask needed to merge")
    arrays = [np.asarray(mask) for mask in masks]
    if weights is None:
        weights = [1 / len(arrays)] * len(arrays)
    final_mask = np.zeros_like(arrays[0], dtype=float)
    for arr, weight in zip(arrays, weights):
        final_mask += arr * weight
    return Image.fromarray(final_mask.astype("uint8"))
def get_merged_mask(
    fp: Path = None,
    img: Image.Image = None,
    conf=0.2,
    iou=0.8,
    inc=50,
    mask_weights=None,
    resize_dim=512,
    model=LAYOUT_MODEL,
    binarized=False,
):
    """Detect layout on several views of an image and return a merged mask.

    Exactly one of `fp` or `img` must be given. The image and its grayscale
    version are each run through layout detection; the per-view masks are
    blended with `mask_weights` (uniform by default), optionally binarized,
    and resized to a (`resize_dim`, `resize_dim`) square.

    Args:
        fp: Path to an image file.
        img: Already-loaded PIL image (alternative to `fp`).
        conf: Detection confidence threshold.
        iou: Detection IoU threshold.
        inc: Forwarded to `get_mask`.
        mask_weights: Optional numpy array of per-view blend weights;
            normalized to sum to 1 without mutating the caller's array.
        resize_dim: Side length of the square output mask.
        model: Layout model to use.
        binarized: If True, threshold the mask to {0, 255} both before and
            after resizing so interpolation artifacts are removed.

    Returns:
        PIL image of size (resize_dim, resize_dim).

    Raises:
        ValueError: If both or neither of `fp` and `img` are given.
    """
    if (fp is None) == (img is None):
        raise ValueError("only one of `fp` or `img` is required")
    if fp:
        log.debug(f"getting merged mask for file {fp}")
        img = Image.open(fp)
    # Views of the image to run detection on; add extra transformations here.
    images = [
        img,
        img.convert("L"),
    ]
    if mask_weights is not None:
        # Normalize a copy: the original used `/=`, which silently mutated
        # the caller's array in place.
        mask_weights = mask_weights / mask_weights.sum()
    masks = []
    for image in images:
        dc = detect(image, conf=conf, iou=iou, plot=False, model=model)
        masks.append(get_mask(*image.size[::-1], dc, inc=inc))
    merged_masks = merge_masks(masks, weights=mask_weights)
    if binarized:
        # `merge_masks` returns a PIL image, which does not support boolean
        # mask assignment -- the original `merged_masks[merged_masks > 0] = 255`
        # raised TypeError. Threshold through numpy instead.
        merged_arr = np.array(merged_masks)
        merged_arr[merged_arr > 0] = 255
        merged_masks = Image.fromarray(merged_arr)
    final_mask = merged_masks.resize((resize_dim, resize_dim))
    if binarized:
        # Re-threshold after resizing to undo interpolation blur.
        final_mask_arr = binarize(np.array(final_mask)) * 255
        final_mask = Image.fromarray(final_mask_arr)
    return final_mask