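# NOTE: the following lines appear to be residue from the web page this file
# was copied from (author, commit message, commit id, view links, file size);
# they are kept as comments so the file remains valid Python.
# enpaiva's picture
# Update app.py
# 8304cf1 verified
# raw
# history blame
# 12.6 kB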
import os
os.environ["GRADIO_TEMP_DIR"] = "./tmp"
import sys
import torch
import gradio as gr
import numpy as np
import cv2
from PIL import Image
from transformers import (
DFineForObjectDetection,
RTDetrV2ForObjectDetection,
RTDetrImageProcessor,
)
# == select device ==
# Prefer the GPU whenever PyTorch reports CUDA support; otherwise use the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Available models with their corresponding model classes.
# Each entry maps a UI display name to the checkpoint identifier handed to
# ``from_pretrained`` ("path") and the transformers class used to load it
# ("model_class"). Insertion order is the order shown in the model dropdown.
MODELS = {
    display_name: {"path": checkpoint, "model_class": model_cls}
    for display_name, checkpoint, model_cls in (
        ("Egret XLarge", "ds4sd/docling-layout-egret-xlarge", DFineForObjectDetection),
        ("Egret Large", "ds4sd/docling-layout-egret-large", DFineForObjectDetection),
        ("Egret Medium", "ds4sd/docling-layout-egret-medium", DFineForObjectDetection),
        ("Heron 101", "ds4sd/docling-layout-heron-101", RTDetrV2ForObjectDetection),
        ("Heron", "ds4sd/docling-layout-heron", RTDetrV2ForObjectDetection),
    )
}
# Classes mapping for the docling model: predicted label id -> class name.
# The id of each class is its position in this list, so the order matters.
classes_map = dict(enumerate([
    "Caption",              # 0
    "Footnote",             # 1
    "Formula",              # 2
    "List-item",            # 3
    "Page-footer",          # 4
    "Page-header",          # 5
    "Picture",              # 6
    "Section-header",       # 7
    "Table",                # 8
    "Text",                 # 9
    "Title",                # 10
    "Document Index",       # 11
    "Code",                 # 12
    "Checkbox-Selected",    # 13
    "Checkbox-Unselected",  # 14
    "Form",                 # 15
    "Key-Value Region",     # 16
]))
# Module-level state shared by load_model() and recognize_image(): the active
# model, its image processor and its MODELS display name. All three stay None
# until a model has been loaded successfully.
current_model = current_processor = current_model_name = None
def colormap(N=256, normalized=False):
    """Build an ``(N, 3)`` palette of distinct label colours.

    Index ``i`` is spread over the three channels three bits at a time:
    bit 0 of each group goes to red, bit 1 to green and bit 2 to blue. Each
    channel is filled from its most significant bit downwards (the classic
    PASCAL VOC label colormap).

    Args:
        N: Number of palette entries to generate.
        normalized: If True, return float32 values in [0, 1] instead of uint8.

    Returns:
        A numpy array of shape ``(N, 3)``. It is uint8 by default and
        float32 when ``normalized`` is True.
    """
    palette = np.zeros((N, 3), dtype=np.uint8)
    for index in range(N):
        channels = [0, 0, 0]
        bits = index
        for shift in range(7, -1, -1):
            for channel in range(3):
                channels[channel] |= ((bits >> channel) & 1) << shift
            bits >>= 3
        palette[index] = channels
    if normalized:
        palette = palette.astype(np.float32) / 255.0
    return palette
def iomin(box1, box2, eps=1e-9):
    """Intersection over Minimum (IoMin) between corner-format boxes.

    For each (broadcast) pair of rows, returns the intersection area divided
    by the smaller of the two box areas. A box lying entirely inside another
    therefore scores 1.0 even when their IoU is small.

    Args:
        box1: Tensor with rows ``(x1, y1, x2, y2)``, e.g. shape ``(1, 4)``.
        box2: Tensor with rows ``(x1, y1, x2, y2)``; its rows broadcast
            against ``box1``.
        eps: Lower bound applied to the denominator. Without it, a
            degenerate (zero- or negative-area) box gives 0/0 = NaN. NaN
            fails the ``<=`` test in ``nms``, so such a box would suppress
            every other box.

    Returns:
        A 1-D tensor of IoMin values, 0.0 for non-overlapping pairs.
    """
    x1 = torch.max(box1[:, 0], box2[:, 0])
    y1 = torch.max(box1[:, 1], box2[:, 1])
    x2 = torch.min(box1[:, 2], box2[:, 2])
    y2 = torch.min(box1[:, 3], box2[:, 3])
    inter_area = torch.clamp(x2 - x1, min=0) * torch.clamp(y2 - y1, min=0)
    box1_area = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    box2_area = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
    # A pair whose smaller area is <= 0 always has zero intersection, so
    # clamping the denominator yields 0 instead of NaN and leaves every
    # non-degenerate pair unchanged.
    min_area = torch.clamp(torch.min(box1_area, box2_area), min=eps)
    return inter_area / min_area
def nms(boxes, scores, iou_threshold=0.5):
    """Greedy non-maximum suppression that measures overlap with IoMin.

    The loop repeatedly keeps the highest-scoring remaining box. It then
    drops every remaining box whose IoMin with that box exceeds
    ``iou_threshold``.

    Args:
        boxes: Box tensor indexed by detection, in the corner layout that
            ``iomin`` expects.
        scores: 1-D confidence tensor aligned with ``boxes``.
        iou_threshold: Largest IoMin a box may have with an already-kept box
            and still survive.

    Returns:
        A ``torch.long`` tensor of kept indices, highest score first, built
        from a Python list (so on PyTorch's default device).
    """
    kept = []
    candidates = scores.sort(descending=True).indices
    while candidates.numel() > 0:
        best = candidates[0]
        kept.append(best.item())
        others = candidates[1:]
        if others.numel() == 0:
            break
        overlap = iomin(boxes[best].unsqueeze(0), boxes[others])
        candidates = others[overlap <= iou_threshold]
    return torch.tensor(kept, dtype=torch.long)
def load_model(model_name):
    """Load the model registered under ``model_name`` and make it active.

    On success the image processor, model and name are stored in the
    module-level ``current_processor`` / ``current_model`` /
    ``current_model_name`` globals. On failure those globals are left
    untouched, so any previously loaded model stays active.

    Args:
        model_name: A key of ``MODELS`` (the dropdown's display name).

    Returns:
        A status string for the UI: an "already loaded" notice, a success
        message, or an error message. Loading errors are reported in the
        returned string rather than raised.
    """
    global current_model, current_processor, current_model_name
    if current_model_name == model_name:
        return f"✅ Model {model_name} is already loaded!"
    try:
        print(f"Loading model: {model_name}")
        model_info = MODELS[model_name]
        model_path = model_info["path"]
        # Every registry entry uses the RT-DETR image processor; only the
        # model class differs between entries.
        processor = RTDetrImageProcessor.from_pretrained(model_path)
        model = model_info["model_class"].from_pretrained(model_path)
        model = model.to(device)
        model.eval()
    except Exception as e:
        return f"❌ Error loading {model_name}: {str(e)}"
    # Publish the new model only after every loading step has succeeded.
    current_processor = processor
    current_model = model
    current_model_name = model_name
    return f"✅ Successfully loaded {model_name}!"
def _to_rgb_array(image_input):
    """Return a writable 3-channel RGB numpy copy of a PIL image or numpy array.

    Numpy input is assumed to be RGB, RGBA or single-channel grayscale
    (NOTE(review): channel order of numpy callers is not visible here).
    """
    if isinstance(image_input, Image.Image):
        return np.array(image_input.convert("RGB"))
    if isinstance(image_input, np.ndarray):
        channels = 1 if image_input.ndim == 2 else image_input.shape[2]
        if channels == 1:
            return cv2.cvtColor(image_input, cv2.COLOR_GRAY2RGB)
        if channels == 4:
            return cv2.cvtColor(image_input, cv2.COLOR_RGBA2RGB)
        return image_input.copy()
    raise ValueError("Input must be PIL Image or numpy array")


def visualize_bbox(image_input, bboxes, classes, scores, id_to_names, alpha=0.3):
    """Draw detections as translucent class-coloured boxes with score labels.

    Drawing happens directly on an RGB buffer, so each box uses the
    ``colormap`` entry for its class as an RGB colour. (Drawing on a BGR
    buffer would swap red and blue relative to the palette.)

    Args:
        image_input: PIL image or numpy array (RGB, RGBA or grayscale).
        bboxes: Sequence of ``(x_min, y_min, x_max, y_max)`` boxes; tensors
            or array-likes.
        classes: Class id per box; tensors or ints.
        scores: Confidence per box; tensors or floats.
        id_to_names: Mapping of class id -> display name. Its length sets the
            palette size, and unknown ids are shown as ``unknown_<id>``.
        alpha: Opacity of the filled box overlay, in [0, 1].

    Returns:
        The annotated image as an RGB numpy array. A box that fails to draw
        is reported with ``print`` and skipped.

    Raises:
        ValueError: If ``image_input`` is neither a PIL image nor a numpy array.
    """
    image = _to_rgb_array(image_input)
    if len(bboxes) == 0:
        return image
    overlay = image.copy()
    cmap = colormap(N=len(id_to_names), normalized=False)
    font, font_scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2
    for i in range(len(bboxes)):
        try:
            bbox = bboxes[i]
            if torch.is_tensor(bbox):
                bbox = bbox.cpu().numpy()
            class_id = classes[i]
            if torch.is_tensor(class_id):
                class_id = class_id.item()
            score = scores[i]
            if torch.is_tensor(score):
                score = score.item()
            x_min, y_min, x_max, y_max = map(int, bbox)
            class_id = int(class_id)
            class_name = id_to_names.get(class_id, f"unknown_{class_id}")
            text = f"{class_name}:{score:.3f}"
            # The buffer is RGB, so the palette entry is used as-is.
            color = tuple(int(c) for c in cmap[class_id % len(cmap)])
            cv2.rectangle(overlay, (x_min, y_min), (x_max, y_max), color, -1)
            cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, 2)
            (text_width, text_height), baseline = cv2.getTextSize(text, font, font_scale, thickness)
            # Keep the label on-canvas when the box touches the top edge.
            label_bottom = max(y_min, text_height + baseline)
            cv2.rectangle(
                image,
                (x_min, label_bottom - text_height - baseline),
                (x_min + text_width, label_bottom),
                color,
                -1,
            )
            cv2.putText(image, text, (x_min, label_bottom - 5), font, font_scale, (255, 255, 255), thickness)
        except Exception as e:
            print(f"Skipping box {i} due to error: {e}")
    # Blend the filled overlay into the outlined/labelled image in place.
    cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)
    return image
def apply_nms(boxes, scores, labels, iou_threshold, nms_method):
    """Filter overlapping detections with the selected NMS variant.

    Args:
        boxes: Box tensor with one row per detection.
        scores: Confidence tensor aligned with ``boxes``.
        labels: Class-id tensor aligned with ``boxes``.
        iou_threshold: Overlap above which the lower-scored box is dropped.
            Values >= 1.0 disable NMS and return the inputs unchanged.
        nms_method: ``"Custom IoMin"`` selects the module's IoMin-based
            ``nms``. Any other value selects torchvision's IoU-based NMS op.

    Returns:
        ``(boxes, scores, labels)`` restricted to the kept detections, in the
        order returned by the chosen NMS routine.
    """
    if iou_threshold >= 1.0:
        return boxes, scores, labels
    if nms_method == "Custom IoMin":
        keep = nms(boxes=boxes, scores=scores, iou_threshold=iou_threshold)
    else:
        # The op schema is ``torchvision::nms(Tensor dets, Tensor scores,
        # float iou_threshold)``, so a ``boxes=`` keyword is rejected; pass
        # the arguments positionally.
        # NOTE(review): this file never imports torchvision itself; the op is
        # only registered once some dependency has imported it — confirm.
        keep = torch.ops.torchvision.nms(boxes, scores, float(iou_threshold))
    # A 1-D index on the tensors' device keeps the per-detection dimension.
    keep = keep.reshape(-1).to(boxes.device)
    return boxes[keep], scores[keep], labels[keep]


def recognize_image(input_img, conf_threshold, iou_threshold, nms_method, alpha):
    """Run the loaded layout model on an image and render its detections.

    Args:
        input_img: PIL image or numpy array, or None when nothing was uploaded.
        conf_threshold: Score threshold passed to the processor's
            ``post_process_object_detection``.
        iou_threshold: NMS overlap threshold; >= 1.0 disables NMS.
        nms_method: ``"Custom IoMin"`` or ``"Standard IoU"``.
        alpha: Opacity of the box overlay drawn by ``visualize_bbox``.

    Returns:
        ``(image, message)``. ``image`` is None when no image was given or no
        model is loaded, the unannotated RGB array when nothing is detected
        or processing fails, and otherwise the annotated RGB array.
        ``message`` is a status line for the UI.
    """
    if input_img is None:
        return None, "Please upload an image first."
    if current_model is None or current_processor is None:
        return None, "Please load a model first."
    try:
        if isinstance(input_img, np.ndarray):
            input_img = Image.fromarray(input_img)
        if input_img.mode != 'RGB':
            input_img = input_img.convert('RGB')
        inputs = current_processor(images=[input_img], return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = current_model(**inputs)
        # PIL's size is (width, height); target_sizes is given as (height, width).
        results = current_processor.post_process_object_detection(
            outputs,
            target_sizes=torch.tensor([input_img.size[::-1]]),
            threshold=conf_threshold,
        )
        if not results:
            return np.array(input_img), "No detections found."
        result = results[0]
        boxes, scores, labels = result["boxes"], result["scores"], result["labels"]
        if len(boxes) == 0:
            return np.array(input_img), "No detections above confidence threshold."
        boxes, scores, labels = apply_nms(boxes, scores, labels, iou_threshold, nms_method)
        output = visualize_bbox(input_img, boxes, labels, scores, classes_map, alpha=alpha)
        if iou_threshold < 1.0:
            detection_info = f"Found {len(boxes)} detections after NMS ({nms_method})"
        else:
            detection_info = f"Found {len(boxes)} detections (NMS disabled)"
        return output, detection_info
    except Exception as e:
        print(f"[ERROR] recognize_image failed: {e}")
        # input_img is never None here (checked above).
        return np.array(input_img), f"Error during processing: {str(e)}"
def gradio_reset():
    """Return Gradio updates that blank the input image, the result image and the info box.

    Returns:
        A 3-tuple of ``gr.update`` objects, in the order
        (input image, output image, detection-info textbox).
    """
    input_update = gr.update(value=None)
    output_update = gr.update(value=None)
    info_update = gr.update(value="")
    return input_update, output_update, info_update
def build_legend_html(class_names, palette):
    """Return an HTML grid with one colour swatch and name per class.

    Args:
        class_names: Mapping of class id -> display name.
        palette: Indexable of RGB triples. Class ``i`` uses
            ``palette[i % len(palette)]``.

    Returns:
        An HTML string suitable for ``gr.HTML``.
    """
    items = []
    for class_id, class_name in class_names.items():
        red, green, blue = (int(c) for c in palette[class_id % len(palette)])
        swatch = (
            '<span style="display:inline-block;width:15px;height:15px;'
            f'background-color:#{red:02x}{green:02x}{blue:02x};'
            'margin-right:5px;border:1px solid #ccc;"></span>'
        )
        items.append(f"<div>{swatch}{class_name}</div>")
    return (
        "<div style='display: grid; grid-template-columns: repeat(3, 1fr); "
        "gap: 10px; font-size: 14px;'>"
        + "".join(items)
        + "</div>"
    )


if __name__ == "__main__":
    print(f"Using device: {device}")
    # Custom CSS for better scrolling and layout
    custom_css = """
    .gradio-container {
        max-width: 1200px !important;
        margin: auto !important;
    }
    .main-content {
        overflow-y: auto !important;
        max-height: 100vh !important;
    }
    """
    with gr.Blocks(title="Document Layout Analysis", theme=gr.themes.Soft(), css=custom_css) as demo:
        # Header
        gr.HTML("""
        <div style="text-align: center; margin-bottom: 20px;">
            <h1>🔍 Document Layout Analysis</h1>
            <p>Using Docling Layout Models for document structure detection</p>
        </div>
        """)
        with gr.Row():
            # Left column: model selection, input image and detection parameters.
            with gr.Column(scale=1):
                model_dropdown = gr.Dropdown(
                    choices=list(MODELS.keys()),
                    value="Egret XLarge",
                    label="🤖 Select Model",
                )
                load_btn = gr.Button("📥 Load Model", variant="secondary", size="sm")
                model_status = gr.Textbox(label="Model Status", interactive=False, value="No model loaded", max_lines=2)
                input_img = gr.Image(label="📄 Upload Image", type="pil", height=300)
                with gr.Row():
                    clear = gr.Button("🗑️ Clear", size="sm")
                    predict = gr.Button("🔍 Detect", variant="primary", size="sm")
                # Parameters
                conf_threshold = gr.Slider(0.0, 1.0, value=0.6, step=0.05, label="Confidence Threshold")
                iou_threshold = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="NMS IoU Threshold")
                nms_method = gr.Radio(["Custom IoMin", "Standard IoU"], value="Custom IoMin", label="NMS Method")
                alpha_slider = gr.Slider(0.0, 1.0, value=0.3, step=0.1, label="Overlay Transparency")
            # Right column: annotated result and status text.
            with gr.Column(scale=1):
                gr.HTML("<h3>🎯 Detection Results</h3>")
                output_img = gr.Image(label="Detected Layout", interactive=False, type="numpy", height=400)
                detection_info = gr.Textbox(label="Detection Info", interactive=False, max_lines=2)
        # Legend at the bottom: one colour swatch per entry of classes_map.
        with gr.Accordion("📋 Detected Classes", open=False):
            gr.HTML(build_legend_html(classes_map, colormap(N=len(classes_map), normalized=False)))
        # Event handlers
        load_btn.click(load_model, inputs=[model_dropdown], outputs=[model_status])
        clear.click(gradio_reset, inputs=None, outputs=[input_img, output_img, detection_info])
        predict.click(
            recognize_image,
            inputs=[input_img, conf_threshold, iou_threshold, nms_method, alpha_slider],
            outputs=[output_img, detection_info],
        )
    # Launch
    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True, share=False)