# Page header captured with this file: "Spaces: Sleeping".
# File size: 4,958 bytes. Revisions shown in the line gutter: f7fbe80, df3522f (lines 1-165).
import numpy as np
from PIL import Image
from pathlib import Path
import supervision as sv
# from ultralytics import YOLO
from doclayout_yolo import YOLOv10
from src.config import YOLO_MODEL_PATH, log
from src.utils import binarize
# Load the layout model once, at import time. `LAYOUT_MODEL` is the default
# value of the `model` parameter of `detect` and `get_merged_mask` in this file.
log.debug("loading YOLO Model")
LAYOUT_MODEL = YOLOv10(YOLO_MODEL_PATH)
# Layout class names. CLASS_COLORS lists one colour per entry, in this order.
CLASSES = [
    "Caption",
    "Footnote",
    "Formula",
    "List-item",
    "Page-footer",
    "Page-header",
    "Picture",
    "Section-header",
    "Table",
    "Text",
    "Title",
    "Unknown",
]
# One colour per class, listed in the same order as CLASSES.
CLASS_COLORS = [
    sv.Color(*rgb)
    for rgb in (
        (255, 0, 0),  # Red for "Caption"
        (0, 255, 0),  # Green for "Footnote"
        (0, 0, 255),  # Blue for "Formula"
        (255, 255, 0),  # Yellow for "List-item"
        (255, 0, 255),  # Magenta for "Page-footer"
        (0, 255, 255),  # Cyan for "Page-header"
        (128, 0, 128),  # Purple for "Picture"
        (128, 128, 0),  # Olive for "Section-header"
        (128, 128, 128),  # Gray for "Table"
        (0, 128, 128),  # Teal for "Text"
        (128, 0, 0),  # Maroon for "Title"
        (255, 255, 255),  # White for "Unknown"
    )
]
# Box annotator: draws bounding boxes with the per-class palette, 2 px thick.
box_annotator = sv.BoxAnnotator(
    thickness=2,
    color=sv.ColorPalette(CLASS_COLORS),
)

# Label annotator: label background uses the same per-class palette as the
# boxes; the label text is white.
label_annotator = sv.LabelAnnotator(
    text_color=sv.Color(255, 255, 255),
    color=sv.ColorPalette(CLASS_COLORS),
)
def detect(img, conf=0.2, iou=0.8, labels=False, plot=False, model=LAYOUT_MODEL):
    """Run the layout model on ``img`` and return the detections.

    Args:
        img: Image handed to ``model``. When ``plot`` is true it must also
            provide ``.copy()`` (e.g. a PIL image or a NumPy array).
        conf: Confidence threshold forwarded to the model call.
        iou: IoU threshold forwarded to the model call.
        labels: When plotting, also draw labels with ``label_annotator``.
        plot: Show the annotated image with ``sv.plot_image``.
        model: Callable detection model; defaults to ``LAYOUT_MODEL``.

    Returns:
        The ``sv.Detections`` built from the first result of the model call.
    """
    # Only the first result is used; a single image is passed in.
    results = model(img, conf=conf, iou=iou, verbose=False)[0]
    detections = sv.Detections.from_ultralytics(results)

    # The annotated image is only ever used for plotting, so skip the drawing
    # work entirely when no plot was requested.
    if not plot:
        return detections

    # Annotate a copy: the annotators may draw directly onto the scene they
    # are given, and the caller's image should stay untouched.
    annotated_image = box_annotator.annotate(scene=img.copy(), detections=detections)
    if labels:
        annotated_image = label_annotator.annotate(
            scene=annotated_image, detections=detections
        )
    sv.plot_image(annotated_image)
    return detections
def sort_coords(xyxy):
    """Return the rows of the 2-D array ``xyxy`` ordered by ascending first column."""
    first_column = xyxy[:, 0]
    row_order = np.argsort(first_column)
    return xyxy[row_order]
def get_labels_with_confidence(detections):
    """Pair each detection's class name with its confidence.

    Reads ``detections.data["class_name"]`` and ``detections.confidence`` and
    returns a list of ``(class_name, confidence)`` tuples in detection order.
    """
    class_names = detections.data["class_name"].tolist()
    scores = detections.confidence.tolist()
    return [(class_name, score) for class_name, score in zip(class_names, scores)]
def rectangle_area(*coords):
    """Return the area of the axis-aligned rectangle ``(x1, y1, x2, y2)``.

    The corners may be given in either order; side lengths are taken as
    absolute differences. Exactly four coordinates are required, otherwise
    the unpacking raises ``ValueError``.
    """
    x1, y1, x2, y2 = coords
    return abs(x2 - x1) * abs(y2 - y1)
def get_mask(img_h, img_w, detections, class_colors=CLASS_COLORS, inc=50):
    """Render detections as a colour-coded RGB mask image.

    Each detection box is filled with the colour of its class on a black
    ``img_h`` x ``img_w`` canvas. Boxes are painted from largest to smallest
    area, so a smaller box lying inside a larger one stays visible.

    Args:
        img_h: Mask height in pixels.
        img_w: Mask width in pixels.
        detections: Object exposing ``xyxy`` (rows of ``x1, y1, x2, y2``) and
            ``class_id`` arrays of equal length.
        class_colors: Sequence indexed by class id; each entry must provide
            ``as_rgb()``.
        inc: Unused; kept so existing callers that pass it keep working.

    Returns:
        A PIL image built from the ``(img_h, img_w, 3)`` uint8 mask array.
    """
    mask_arr = np.zeros((img_h, img_w, 3), dtype=np.uint8)

    xyxy = np.asarray(detections.xyxy, dtype=float).reshape(-1, 4)
    areas = np.abs(xyxy[:, 2] - xyxy[:, 0]) * np.abs(xyxy[:, 3] - xyxy[:, 1])
    # Largest area first; a stable sort keeps equal-area boxes in input order.
    order = np.argsort(-areas, kind="stable")

    boxes = np.round(xyxy).astype(np.int32)
    # Slice bounds: the end corner is inclusive, hence the +1. Negative values
    # are clamped to 0 because a negative slice index would wrap around to the
    # far edge of the array; bounds past the far edge are clamped by slicing.
    starts = np.maximum(boxes[:, :2], 0)
    stops = np.maximum(boxes[:, 2:] + 1, 0)

    for (left, top), (right_stop, bottom_stop), class_id in zip(
        starts[order].tolist(), stops[order].tolist(), detections.class_id[order]
    ):
        # Rows are indexed by y and columns by x, both directly in image
        # coordinates, so the box covers rows y1..y2 and columns x1..x2.
        mask_arr[top:bottom_stop, left:right_stop, :] = class_colors[class_id].as_rgb()

    return Image.fromarray(mask_arr)
def merge_masks(masks, weights=None):
    """Blend several same-shaped masks into one image by a weighted sum.

    Args:
        masks: Sequence of at least two masks (anything ``np.array`` accepts,
            e.g. PIL images) that share one shape.
        weights: One weight per mask. Defaults to equal weights ``1 / len(masks)``.

    Returns:
        A PIL image of the weighted sum, rounded to the nearest integer and
        clipped to the uint8 range.

    Raises:
        ValueError: If fewer than two masks are given, or if the number of
            weights differs from the number of masks.
    """
    # `< 2` also rejects an empty sequence, which would otherwise fail later
    # with a ZeroDivisionError or IndexError.
    if len(masks) < 2:
        raise ValueError("more than 1 mask needed to merge")
    if weights is None:
        weights = [1 / len(masks)] * len(masks)
    elif len(weights) != len(masks):
        # zip() would silently drop the unmatched masks or weights.
        raise ValueError(f"expected {len(masks)} weights, got {len(weights)}")

    arrays = [np.array(mask) for mask in masks]
    blended = np.zeros_like(arrays[0], dtype=float)
    for arr, weight in zip(arrays, weights):
        blended += arr * weight

    # Round instead of truncating (a float sum such as 199.99999999999997
    # should become 200) and clip so out-of-range sums cannot wrap around in
    # the uint8 cast.
    blended = np.clip(np.rint(blended), 0, 255)
    return Image.fromarray(blended.astype("uint8"))
def get_merged_mask(
    fp: Path = None,
    img: Image.Image = None,
    conf=0.2,
    iou=0.8,
    inc=50,
    mask_weights=None,
    resize_dim=512,
    model=LAYOUT_MODEL,
    binarized=False,
):
    """Build one layout mask from detections on several variants of an image.

    The image is run through ``detect`` as-is and converted to grayscale
    (``"L"``). Each set of detections is rendered with ``get_mask``, the masks
    are blended with ``merge_masks``, and the result is resized to a square.

    Args:
        fp: Path of the image to open. Exactly one of ``fp`` / ``img`` is given.
        img: Already loaded PIL image. Exactly one of ``fp`` / ``img`` is given.
        conf: Confidence threshold forwarded to ``detect``.
        iou: IoU threshold forwarded to ``detect``.
        inc: Forwarded to ``get_mask``.
        mask_weights: Optional weights, one per image variant. They are
            normalised to sum to 1; the caller's object is not modified.
        resize_dim: Side length of the square output image.
        model: Detection model forwarded to ``detect``.
        binarized: If true, every non-zero value of the merged mask is set to
            255 before resizing, and the resized mask is passed through
            ``binarize`` and scaled by 255.

    Returns:
        The final mask as a PIL image of size ``(resize_dim, resize_dim)``.

    Raises:
        ValueError: If both or neither of ``fp`` and ``img`` are given.
    """
    if (fp is None) == (img is None):
        raise ValueError("only one of `fp` or `img` is required")
    if fp is not None:
        log.debug(f"getting merged mask for file {fp}")
        # Copy the pixel data out so the underlying file can be closed here.
        with Image.open(fp) as opened:
            img = opened.copy()

    images = [
        img,
        img.convert("L"),
        # any other extra transformations
    ]

    if mask_weights is not None:
        # Normalise into a new float array: in-place division would mutate the
        # caller's array and fails for lists and integer arrays.
        mask_weights = np.asarray(mask_weights, dtype=float)
        mask_weights = mask_weights / mask_weights.sum()

    masks = []
    for image in images:
        dc = detect(image, conf=conf, iou=iou, plot=False, model=model)
        # get_mask takes (height, width), so the image size tuple is reversed.
        masks.append(get_mask(*image.size[::-1], dc, inc=inc))

    merged_mask = merge_masks(masks, weights=mask_weights)
    if binarized:
        # merge_masks returns a PIL image, which supports neither comparison
        # nor boolean-mask assignment; threshold on an array copy instead.
        merged_arr = np.array(merged_mask)
        merged_arr[merged_arr > 0] = 255
        merged_mask = Image.fromarray(merged_arr)

    final_mask = merged_mask.resize((resize_dim, resize_dim))
    if binarized:
        # NOTE(review): `binarize` comes from src.utils and is not visible
        # here; this assumes `binarize(...) * 255` yields an array dtype that
        # Image.fromarray accepts - confirm.
        final_mask_arr = binarize(np.array(final_mask)) * 255
        final_mask = Image.fromarray(final_mask_arr)
    return final_mask
|