import os
os.environ["GRADIO_TEMP_DIR"] = "./tmp"
import sys
import torch
import torchvision  # provides torchvision.ops.nms for the "Standard IoU" NMS option
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from transformers import (
    DFineForObjectDetection,
    RTDetrImageProcessor,
)
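# Note: DFineForObjectDetection and RTDetrImageProcessor require a recent
# transformers release with D-FINE support; the exact minimum version is an
# assumption here, not something pinned by this file.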
# == select device ==
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Available models
MODELS = {
"Egret XLarge": "ds4sd/docling-layout-egret-xlarge",
"Egret Large": "ds4sd/docling-layout-egret-large",
"Egret Medium": "ds4sd/docling-layout-egret-medium",
"Heron 101": "ds4sd/docling-layout-heron-101",
"Heron": "ds4sd/docling-layout-heron"
}
# Classes mapping for the docling model
classes_map = {
0: "Caption",
1: "Footnote",
2: "Formula",
3: "List-item",
4: "Page-footer",
5: "Page-header",
6: "Picture",
7: "Section-header",
8: "Table",
9: "Text",
10: "Title",
11: "Document Index",
12: "Code",
13: "Checkbox-Selected",
14: "Checkbox-Unselected",
15: "Form",
16: "Key-Value Region",
}
# Color mapping for visualization
colors = [
"#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4", "#FECA57",
"#FF9FF3", "#54A0FF", "#5F27CD", "#00D2D3", "#FF9F43",
"#10AC84", "#EE5A24", "#0ABDE3", "#006BA6", "#F79F1F",
"#A3CB38", "#FDA7DF"
]
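# Note: the palette has 17 entries, one per class id in classes_map (0-16);
# visualize_bbox still applies modulo indexing as a safeguard for unexpected ids.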
# Global variables for model
current_model = None
current_processor = None
current_model_name = None
def iomin(box1, box2):
    """
    Intersection over Minimum (IoMin)
    box1: Tensor[1, 4]
    box2: Tensor[N, 4]
    Returns: Tensor[N]
    """
    # Intersection
    x1 = torch.max(box1[:, 0], box2[:, 0])
    y1 = torch.max(box1[:, 1], box2[:, 1])
    x2 = torch.min(box1[:, 2], box2[:, 2])
    y2 = torch.min(box1[:, 3], box2[:, 3])
    inter_area = torch.clamp(x2 - x1, min=0) * torch.clamp(y2 - y1, min=0)

    # Areas
    box1_area = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    box2_area = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
    min_area = torch.min(box1_area, box2_area)

    return inter_area / min_area
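
# Illustrative example of why IoMin differs from IoU: for box1 = [0, 0, 10, 10]
# and box2 = [2, 2, 8, 8] (box2 fully contained in box1), the intersection area
# is 36, so IoMin = 36 / min(100, 36) = 1.0 while IoU = 36 / 100 = 0.36.
# IoMin-based NMS therefore suppresses nested duplicate regions that a plain
# IoU threshold of 0.5 would leave in place.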
def nms(boxes, scores, iou_threshold=0.5):
    """
    Custom NMS implementation using IoMin
    """
    keep = []
    _, order = scores.sort(descending=True)

    while order.numel() > 0:
        i = order[0]
        keep.append(i.item())

        if order.numel() == 1:
            break

        box_i = boxes[i].unsqueeze(0)  # [1, 4]
        rest = order[1:]
        ious = iomin(box_i, boxes[rest])

        mask = (ious <= iou_threshold)
        order = order[1:][mask]

    return torch.tensor(keep, dtype=torch.long)
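
# Usage sketch (mirrors the call made in recognize_image below):
#   keep = nms(boxes, scores, iou_threshold=0.5)
#   boxes, scores, labels = boxes[keep], scores[keep], labels[keep]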
def load_model(model_name):
    """
    Load the selected model
    """
    global current_model, current_processor, current_model_name

    if current_model_name == model_name:
        return f"✅ Model {model_name} is already loaded!"

    try:
        print(f"Loading model: {model_name}")
        model_path = MODELS[model_name]

        processor = RTDetrImageProcessor.from_pretrained(model_path)
        model = DFineForObjectDetection.from_pretrained(model_path)
        model = model.to(device)
        model.eval()

        current_processor = processor
        current_model = model
        current_model_name = model_name

        return f"✅ Successfully loaded {model_name}!"
    except Exception as e:
        return f"❌ Error loading {model_name}: {str(e)}"
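
# The loaded model and processor are cached in the module-level globals above,
# so repeated "Load Model" clicks with the same selection reuse the already
# loaded weights; only changing the dropdown selection triggers a fresh load.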
def visualize_bbox(image, boxes, labels, scores, classes_map, colors):
    """
    Visualize bounding boxes on image
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif not isinstance(image, Image.Image):
        raise ValueError("Input image must be PIL Image or numpy array")

    # Create a copy to draw on
    draw_image = image.copy()
    draw = ImageDraw.Draw(draw_image)

    # Try to use a TrueType font, fall back to the default bitmap font if unavailable
    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except Exception:
        try:
            font = ImageFont.load_default()
        except Exception:
            font = None

    for box, label_id, score in zip(boxes, labels, scores):
        # Convert tensors to Python scalars if needed
        if torch.is_tensor(label_id):
            label_id = label_id.item()
        if torch.is_tensor(score):
            score = score.item()

        label = classes_map.get(int(label_id), f"Class_{label_id}")
        color = colors[int(label_id) % len(colors)]

        # Convert box coordinates to integers
        x1, y1, x2, y2 = [int(coord) for coord in box]

        # Draw rectangle
        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)

        # Draw label background and text
        text = f"{label}: {score:.2f}"
        if font:
            bbox = draw.textbbox((x1, y1), text, font=font)
            text_width = bbox[2] - bbox[0]
            text_height = bbox[3] - bbox[1]
        else:
            # Estimate text size if no font is available
            text_width = len(text) * 10
            text_height = 20

        draw.rectangle([x1, y1 - text_height - 4, x1 + text_width + 4, y1], fill=color)
        draw.text((x1 + 2, y1 - text_height - 2), text, fill="white", font=font)

    return np.array(draw_image)
def recognize_image(input_img, conf_threshold, iou_threshold, nms_method):
    """
    Process image with docling layout model
    """
    if input_img is None:
        return None, "Please upload an image first."

    if current_model is None or current_processor is None:
        return None, "Please load a model first."

    try:
        # Ensure image is a PIL Image
        if isinstance(input_img, np.ndarray):
            input_img = Image.fromarray(input_img)

        # Convert to RGB if needed
        if input_img.mode != 'RGB':
            input_img = input_img.convert('RGB')

        # Process image
        inputs = current_processor(images=[input_img], return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Run inference
        with torch.no_grad():
            outputs = current_model(**inputs)

        # Post-process results; target_sizes expects (height, width), hence size[::-1]
        results = current_processor.post_process_object_detection(
            outputs,
            target_sizes=torch.tensor([input_img.size[::-1]]),
            threshold=conf_threshold,
        )

        if not results or len(results) == 0:
            return np.array(input_img), "No detections found."

        result = results[0]

        # Get results
        boxes = result["boxes"]
        scores = result["scores"]
        labels = result["labels"]

        if len(boxes) == 0:
            return np.array(input_img), "No detections above confidence threshold."

        # Apply NMS if requested
        if iou_threshold < 1.0:
            if nms_method == "Custom IoMin":
                # Use custom NMS with IoMin
                keep_indices = nms(
                    boxes=boxes,
                    scores=scores,
                    iou_threshold=iou_threshold
                )
            else:
                # Use standard IoU-based NMS from torchvision
                keep_indices = torchvision.ops.nms(
                    boxes,
                    scores,
                    iou_threshold
                )

            boxes = boxes[keep_indices]
            scores = scores[keep_indices]
            labels = labels[keep_indices]

        # Handle single detection case
        if len(boxes.shape) == 1:
            boxes = boxes.unsqueeze(0)
            scores = scores.unsqueeze(0)
            labels = labels.unsqueeze(0)

        # Visualize results
        output = visualize_bbox(
            input_img,
            boxes,
            labels,
            scores,
            classes_map,
            colors
        )

        detection_info = f"Found {len(boxes)} detections after NMS ({nms_method})"
        return output, detection_info

    except Exception as e:
        print(f"[ERROR] recognize_image failed: {e}")
        error_msg = f"Error during processing: {str(e)}"

        # Return original image on error
        if input_img is not None:
            return np.array(input_img), error_msg
        return np.zeros((512, 512, 3), dtype=np.uint8), error_msg
def gradio_reset():
    return gr.update(value=None), gr.update(value=None), gr.update(value="")
if __name__ == "__main__":
    print(f"Using device: {device}")

    # Create header HTML
    header_html = """
    <div style="text-align: center; margin-bottom: 20px;">
        <h1>🔍 Document Layout Analysis</h1>
        <p>Using Docling Layout Models for document structure detection</p>
        <p>Select a model, upload an image and adjust the parameters to detect document elements</p>
    </div>
    """

    with gr.Blocks(title="Document Layout Analysis", theme=gr.themes.Soft()) as demo:
        gr.HTML(header_html)

        with gr.Row():
            with gr.Column():
                # Model selection
                model_dropdown = gr.Dropdown(
                    choices=list(MODELS.keys()),
                    value="Egret XLarge",
                    label="🤖 Select Model",
                    info="Choose which Docling model to use"
                )
                load_btn = gr.Button("📥 Load Model", variant="secondary")
                model_status = gr.Textbox(
                    label="Model Status",
                    interactive=False,
                    value="No model loaded"
                )

                input_img = gr.Image(
                    label="📄 Upload Document Image",
                    interactive=True,
                    type="pil"
                )

                with gr.Row():
                    clear = gr.Button("🗑️ Clear")
                    predict = gr.Button("🔍 Detect Layout", interactive=True, variant="primary")

                with gr.Row():
                    conf_threshold = gr.Slider(
                        label="Confidence Threshold",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        value=0.6,
                        info="Minimum confidence score for detections"
                    )

                with gr.Row():
                    iou_threshold = gr.Slider(
                        label="NMS IoU Threshold",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        value=0.5,
                        info="IoU threshold for Non-Maximum Suppression"
                    )

                nms_method = gr.Radio(
                    choices=["Custom IoMin", "Standard IoU"],
                    value="Custom IoMin",
                    label="NMS Method",
                    info="Choose NMS algorithm"
                )

                # Legend
                with gr.Accordion("📋 Detected Classes", open=False):
                    legend_html = "<div style='display: grid; grid-template-columns: repeat(2, 1fr); gap: 10px;'>"
                    for class_id, class_name in classes_map.items():
                        color = colors[class_id % len(colors)]
                        legend_html += f"""
                        <div style='display: flex; align-items: center; padding: 5px;'>
                            <div style='width: 20px; height: 20px; background-color: {color}; margin-right: 10px; border: 1px solid #ccc;'></div>
                            <span>{class_name}</span>
                        </div>
                        """
                    legend_html += "</div>"
                    gr.HTML(legend_html)

            with gr.Column():
                gr.HTML("<h3>🎯 Detection Results</h3>")
                output_img = gr.Image(
                    label="Detected Layout",
                    interactive=False,
                    type="numpy"
                )
                detection_info = gr.Textbox(
                    label="Detection Info",
                    interactive=False,
                    value=""
                )

        # Event handlers
        load_btn.click(
            load_model,
            inputs=[model_dropdown],
            outputs=[model_status]
        )

        clear.click(
            gradio_reset,
            inputs=None,
            outputs=[input_img, output_img, detection_info]
        )

        predict.click(
            recognize_image,
            inputs=[input_img, conf_threshold, iou_threshold, nms_method],
            outputs=[output_img, detection_info]
        )

    # Launch the demo
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        debug=True,
        share=False
    )