import io
import json
import os
import tempfile
from pathlib import Path

import gradio as gr
import numpy as np
import requests
from PIL import Image, ImageDraw, ImageFont

from pipeline import create_labelme_json, clean_labelme

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Hosted Ultralytics inference endpoint. Both values are read from environment
# variables (e.g. Space secrets). There are NO built-in fallbacks: if API_URL
# is unset, `call_api` raises a configuration error that is shown in the UI.
API_URL = os.getenv("API_URL")
API_KEY = os.getenv("API_KEY")

IMAGE_FOLDER = "images"
# File extensions (lower-case) offered as selectable test images.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif")

# Color per class keyword (RGB)
CLASS_COLORS = {
    'column': (255, 165, 0),   # orange
    'row': (0, 200, 0),        # green
    'header': (30, 120, 255),  # blue
    'line': (230, 230, 0),     # yellow
}
DEFAULT_COLOR = (255, 0, 0)  # red


def color_for_label(label):
    """Return the RGB color for *label*.

    The first CLASS_COLORS keyword contained (case-insensitively) in the label
    wins; labels matching no keyword get DEFAULT_COLOR.
    """
    low = label.lower()
    for key, color in CLASS_COLORS.items():
        if key in low:
            return color
    return DEFAULT_COLOR


# ---------------------------------------------------------------------------
# Test images
# ---------------------------------------------------------------------------
def get_test_images():
    """Return ``(path, filename)`` pairs for the sample images in IMAGE_FOLDER.

    Entries are sorted by path and filtered by extension (case-insensitive).
    Returns an empty list when the folder does not exist.
    """
    folder = Path(IMAGE_FOLDER)
    if not folder.exists():
        return []
    return [
        (str(file), file.name)
        for file in sorted(folder.glob("*"))
        if file.suffix.lower() in IMAGE_EXTENSIONS
    ]


def load_test_image(image_path):
    """Open *image_path* as an RGB PIL image, or return None if it is missing.

    The file handle is closed before returning; ``convert`` yields an
    in-memory copy that does not depend on the open file.
    """
    if image_path and os.path.exists(image_path):
        with Image.open(image_path) as img:
            return img.convert("RGB")
    return None


# ---------------------------------------------------------------------------
# Drawing
# ---------------------------------------------------------------------------
def _load_font(font_size):
    """Return the first loadable bold TrueType font of *font_size*.

    Tries common Linux/macOS/Windows locations and falls back to PIL's
    built-in default font when none can be loaded.
    """
    font_paths = [
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
        "/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf",
        "/System/Library/Fonts/Arial.ttf",
        "C:\\Windows\\Fonts\\arial.ttf",
        "arial.ttf",
    ]
    for path in font_paths:
        if os.path.exists(path):
            try:
                return ImageFont.truetype(path, font_size)
            except Exception:
                # Unreadable/corrupt font file: try the next candidate.
                continue
    return ImageFont.load_default()


def draw_shapes_on_image(image, shapes):
    """Draw cleaned labelme rectangle shapes onto a copy of a PIL image.

    Each shape's axis-aligned bounding box (min/max over its ``points``) is
    outlined in the color from ``color_for_label`` and tagged with its label
    on a black background. Zero-area boxes are skipped. When *shapes* is
    empty, *image* itself is returned unchanged.
    """
    if not shapes:
        return image

    img = image.copy()
    draw = ImageDraw.Draw(img)
    img_w, img_h = img.size
    # Scale text and outline with the image so they stay legible on large scans.
    min_dim = min(img_w, img_h)
    font_size = max(int(min_dim * 0.018), 16)
    line_width = max(int(min_dim * 0.004), 2)
    font = _load_font(font_size)
    pad = 3

    for shape in shapes:
        pts = np.asarray(shape["points"])
        x1, y1 = int(np.min(pts[:, 0])), int(np.min(pts[:, 1]))
        x2, y2 = int(np.max(pts[:, 0])), int(np.max(pts[:, 1]))
        label = shape["label"]
        color = color_for_label(label)
        if x2 <= x1 or y2 <= y1:
            continue

        draw.rectangle([x1, y1, x2, y2], outline=color, width=line_width)

        bbox = draw.textbbox((0, 0), label, font=font)
        tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]
        tag_w = tw + 2 * pad
        # Put the tag just above the box (clamped to the top edge) and keep it
        # from running past the right edge for boxes near the image border.
        ty = max(0, y1 - th - 6)
        tx = max(0, min(x1, img_w - tag_w))
        draw.rectangle([tx, ty, tx + tag_w, ty + th + 2 * pad], fill=(0, 0, 0))
        draw.text((tx + pad, ty + pad), label, font=font, fill=color)
    return img


# ---------------------------------------------------------------------------
# Prediction
# ---------------------------------------------------------------------------
def _escape_md_cell(text):
    """Escape characters that would break a Markdown table cell."""
    return str(text).replace("|", "\\|").replace("\n", " ")


def format_results(shapes, img_w, img_h):
    """Build the Markdown report shown next to the annotated image.

    Lists the image size, the number of shapes and, if any, a table of
    label/confidence pairs. The confidence is taken from each shape's
    ``description`` with a ``confidence:`` prefix removed ("N/A" if empty).
    """
    parts = [
        "## Detection Results\n\n",
        f"**Image Size:** {img_w} x {img_h} (W x H)\n\n",
        f"**Shapes Found:** {len(shapes)}\n\n",
    ]
    if shapes:
        parts.append("### Detected Objects\n")
        parts.append("| Label | Confidence |\n")
        parts.append("|-------|------------|\n")
        for s in shapes:
            desc = s.get("description", "")
            conf = desc.replace("confidence:", "").strip() if desc else "N/A"
            parts.append(
                f"| {_escape_md_cell(s['label'])} | {_escape_md_cell(conf)} |\n")
    return "".join(parts)


def call_api(image, confidence, iou, imgsz):
    """POST *image* to the hosted Ultralytics endpoint and return the parsed JSON.

    Args:
        image: PIL image. It is JPEG-encoded, so it must be in a JPEG-compatible
            mode (``predict_image`` converts to RGB first).
        confidence: sent as the ``conf`` form field.
        iou: sent as the ``iou`` form field.
        imgsz: sent as the ``imgsz`` form field.

    Raises:
        RuntimeError: if ``API_URL`` is not configured.
        requests.exceptions.RequestException: on connection failure, a 60 s
            timeout, or a non-2xx response (via ``raise_for_status``).
    """
    if not API_URL:
        raise RuntimeError(
            "API_URL is not configured; set the API_URL environment variable.")

    img_bytes = io.BytesIO()
    image.save(img_bytes, format="JPEG")
    img_bytes.seek(0)

    params = {"conf": confidence, "iou": iou, "imgsz": imgsz}
    # Only send credentials when a key is configured ("Bearer None" is useless).
    headers = {"Authorization": f"Bearer {API_KEY}"} if API_KEY else {}
    files = {"file": ("image.jpg", img_bytes, "image/jpeg")}
    response = requests.post(API_URL, headers=headers, data=params,
                             files=files, timeout=60)
    response.raise_for_status()
    return response.json()


def api_results_to_detections(api_result):
    """Convert the API response into the pipeline's detections dict.

    Reads ``api_result["images"][0]["results"]``; each entry's ``box``
    (x1/y1/x2/y2) becomes a 4-corner polygon together with its confidence,
    class name and class id. Missing or non-dict responses yield no boxes.

    Returns:
        ``{"boxes": [...]}`` in the format ``create_labelme_json`` receives.
    """
    boxes = []
    images = api_result.get("images", []) if isinstance(api_result, dict) else []
    if images:
        for det in images[0].get("results") or []:
            box = det.get("box") or {}
            x1 = float(box.get("x1", 0))
            y1 = float(box.get("y1", 0))
            x2 = float(box.get("x2", 0))
            y2 = float(box.get("y2", 0))
            boxes.append({
                "points": [[x1, y1], [x2, y1], [x2, y2], [x1, y2]],
                "confidence": float(det.get("confidence", 0)),
                "class_name": det.get("name", "unknown"),
                "class_id": int(det.get("class", 0)),
            })
    return {"boxes": boxes}


def _write_result_json(labelme_json):
    """Write *labelme_json* to a fresh ``result.json`` and return its path.

    Each call uses its own new temporary directory. Concurrent sessions
    therefore never overwrite each other's download, and the file still keeps
    the name ``result.json``.
    """
    out_dir = tempfile.mkdtemp(prefix="table_layout_")
    json_path = os.path.join(out_dir, "result.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(labelme_json, f, indent=2)
    return json_path


def predict_image(image, confidence, iou, imgsz):
    """Run detection on *image* for the Gradio UI.

    Returns:
        ``(annotated_image, json_path, markdown_report)`` on success, or
        ``(None, None, error_markdown)`` when there is no image or any step
        fails. Errors are rendered as messages instead of raised, so the UI
        always gets a value for all three outputs.
    """
    if image is None:
        return None, None, "#### Please upload an image to begin detection"

    try:
        image = image.convert("RGB")
        api_result = call_api(image, float(confidence), float(iou), int(imgsz))
        detections = api_results_to_detections(api_result)

        # Build + clean labelme JSON (rows span columns, columns span header->last row, dedupe)
        labelme_json = create_labelme_json(
            "image.png", detections, image.height, image.width)
        labelme_json = clean_labelme(labelme_json)
        shapes = labelme_json["shapes"]

        result_img = draw_shapes_on_image(image, shapes)
        report = format_results(shapes, image.width, image.height)
        json_path = _write_result_json(labelme_json)
        return result_img, json_path, report
    except requests.exceptions.Timeout:
        return None, None, "#### Error: Request timeout. Please try again."
    except requests.exceptions.ConnectionError:
        return None, None, "#### Error: Unable to connect to detection service."
    except requests.exceptions.HTTPError as e:
        # HTTPError.response can be None when the error is raised by hand.
        status = e.response.status_code if e.response is not None else "unknown"
        return None, None, f"#### Error: API returned status {status}"
    except Exception as e:
        # Top-level UI boundary: surface any other failure in the results pane.
        return None, None, f"#### Error: {str(e)}"


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
dark_theme = gr.themes.Monochrome(
    primary_hue="slate",
    secondary_hue="slate",
).set(
    body_text_color="#e0e0e0",
    background_fill_primary="#0f0f0f",
    background_fill_secondary="#1a1a1a",
)

with gr.Blocks(title="Table Layout Detection") as demo:
    gr.Markdown(
        """
# Table Layout Detection

Detect table columns, rows and headers. Upload an image and adjust the
inference parameters. Boxes are auto-cleaned (rows span all columns,
columns span header→last row, duplicates removed) before being drawn.
"""
    )

    with gr.Row():
        with gr.Column(scale=1, min_width=400):
            gr.Markdown("### Input")
            image_input = gr.Image(label="Image", type="pil",
                                   sources=["upload"], interactive=True)

            test_images = get_test_images()
            if test_images:
                # Display name -> file path (names are unique within one folder).
                test_image_paths = {name: path for path, name in test_images}
                test_image_radio = gr.Radio(
                    choices=[name for _, name in test_images],
                    label="Select test image",
                    info="Click to load",
                )
                test_image_radio.change(
                    fn=lambda name: load_test_image(test_image_paths.get(name)),
                    inputs=[test_image_radio],
                    outputs=[image_input],
                )
            else:
                gr.Markdown("No test images found. Add images to the 'images' folder.")

            gr.Markdown("### Configuration")
            confidence_slider = gr.Slider(label="Confidence Threshold",
                                          minimum=0.0, maximum=1.0, value=0.2,
                                          step=0.01,
                                          info="Detection confidence level")
            iou_slider = gr.Slider(label="IOU Threshold (NMS)",
                                   minimum=0.0, maximum=1.0, value=0.2,
                                   step=0.01,
                                   info="Intersection over union threshold")
            imgsz_slider = gr.Slider(label="Image Size",
                                     minimum=320, maximum=2048, value=1280,
                                     step=32,
                                     info="Inference image resolution")

            predict_btn = gr.Button("Detect Objects", variant="primary", size="lg")

        with gr.Column(scale=1, min_width=400):
            gr.Markdown("### Results")
            image_output = gr.Image(label="Detections", type="pil", interactive=False)
            json_output = gr.File(label="Download labelme JSON")
            results_output = gr.Markdown(value="Detection results will appear here.")

    inputs = [image_input, confidence_slider, iou_slider, imgsz_slider]
    outputs = [image_output, json_output, results_output]

    # Re-run detection on the button and whenever the image or any parameter changes.
    predict_btn.click(fn=predict_image, inputs=inputs, outputs=outputs)
    image_input.change(fn=predict_image, inputs=inputs, outputs=outputs)
    confidence_slider.change(fn=predict_image, inputs=inputs, outputs=outputs)
    iou_slider.change(fn=predict_image, inputs=inputs, outputs=outputs)
    imgsz_slider.change(fn=predict_image, inputs=inputs, outputs=outputs)


if __name__ == "__main__":
    demo.launch(share=False, show_error=True, theme=dark_theme,
                css="footer {display: none !important;}")