Spaces:
Sleeping
Sleeping
| import os | |
| import io | |
| import json | |
| from pathlib import Path | |
| import requests | |
| import numpy as np | |
| import gradio as gr | |
| from PIL import Image, ImageDraw, ImageFont | |
| from pipeline import create_labelme_json, clean_labelme | |
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Hosted Ultralytics inference endpoint and its API key, read from environment
# variables (set them as Space secrets). There are no built-in fallbacks: when
# a variable is unset the corresponding constant is None.
API_URL = os.getenv("API_URL")
API_KEY = os.getenv("API_KEY")
# Folder scanned for sample images offered in the UI's test-image picker.
IMAGE_FOLDER = "images"
# RGB color per class keyword. Keywords are checked in insertion order, so the
# first keyword contained in a label decides its color.
CLASS_COLORS = dict(
    column=(255, 165, 0),  # orange
    row=(0, 200, 0),  # green
    header=(30, 120, 255),  # blue
    line=(230, 230, 0),  # yellow
)
# Color used when no keyword occurs in the label.
DEFAULT_COLOR = (255, 0, 0)  # red


def color_for_label(label):
    """Return the RGB color for ``label``.

    The label is lower-cased and compared against the ``CLASS_COLORS`` keys by
    substring containment; the first matching key's color is returned, or
    ``DEFAULT_COLOR`` when none matches.
    """
    lowered = label.lower()
    return next(
        (rgb for keyword, rgb in CLASS_COLORS.items() if keyword in lowered),
        DEFAULT_COLOR,
    )
# ---------------------------------------------------------------------------
# Test images
# ---------------------------------------------------------------------------
def get_test_images(folder=None):
    """List the sample images stored directly inside ``folder``.

    Args:
        folder: Directory to scan. Defaults to ``IMAGE_FOLDER`` when None.

    Returns:
        A list of ``(path, file_name)`` tuples sorted by path. Only regular
        files whose extension (case-insensitive) is .jpg, .jpeg, .png, .bmp
        or .gif are included. The list is empty if ``folder`` is not an
        existing directory.
    """
    folder = IMAGE_FOLDER if folder is None else folder
    root = Path(folder)
    if not root.is_dir():
        return []
    suffixes = {".jpg", ".jpeg", ".png", ".bmp", ".gif"}
    # is_file() drops sub-directories that merely carry an image-like suffix
    # (e.g. "scans.png/"); those cannot be opened as images later.
    return [
        (str(entry), entry.name)
        for entry in sorted(root.glob("*"))
        if entry.is_file() and entry.suffix.lower() in suffixes
    ]
def load_test_image(image_path):
    """Load ``image_path`` as an RGB PIL image.

    Args:
        image_path: Path to an image file; may be None or empty.

    Returns:
        The converted RGB image, or None when ``image_path`` is empty or does
        not point to a regular file.
    """
    # isfile (not exists) so a directory path yields None instead of an
    # Image.open error.
    if not image_path or not os.path.isfile(image_path):
        return None
    # The context manager releases the file handle once the RGB copy exists;
    # Pillow otherwise keeps multi-frame files (e.g. GIF) open until closed.
    with Image.open(image_path) as img:
        return img.convert("RGB")
# ---------------------------------------------------------------------------
# Drawing
# ---------------------------------------------------------------------------
def _load_font(font_size):
    """Return a TrueType font at ``font_size``, falling back to Pillow's default.

    Candidates are tried in order: DejaVu Sans Bold and Liberation Sans Bold at
    their usual Linux paths, then Arial at macOS/Windows paths, then
    "arial.ttf" relative to the current working directory. A candidate is used
    only if it exists on disk and ``ImageFont.truetype`` loads it without
    raising. The fallback is created without a size argument, so it does not
    follow ``font_size``.
    """
    candidates = (
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
        "/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf",
        "/System/Library/Fonts/Arial.ttf",
        "C:\\Windows\\Fonts\\arial.ttf",
        "arial.ttf",
    )
    for candidate in filter(os.path.exists, candidates):
        try:
            return ImageFont.truetype(candidate, font_size)
        except Exception:
            # Unreadable or unsupported font file: try the next candidate.
            pass
    return ImageFont.load_default()
def draw_shapes_on_image(image, shapes):
    """Draw cleaned labelme rectangle shapes onto a copy of a PIL image.

    Each shape is drawn as the axis-aligned box spanning its points, outlined
    in ``color_for_label(label)``. Its label text sits on a black tag just
    above the box, clamped to the top edge of the image. Shapes with malformed
    points or zero width/height are skipped.

    Args:
        image: PIL image to annotate. It is not modified.
        shapes: Iterable of dicts with ``"points"`` (a sequence of ``[x, y]``
            pairs) and ``"label"`` (a string).

    Returns:
        A new annotated image, or ``image`` itself when ``shapes`` is empty.
    """
    if not shapes:
        return image
    img = image.copy()
    draw = ImageDraw.Draw(img)
    img_w, img_h = img.size
    min_dim = min(img_w, img_h)
    # Scale text and strokes with the image, with floors for small images.
    font_size = max(int(min_dim * 0.018), 16)
    line_width = max(int(min_dim * 0.004), 2)
    font = _load_font(font_size)
    pad = 3  # px between the label text and the edge of its black tag
    for shape in shapes:
        pts = np.asarray(shape["points"])
        # Skip shapes without usable [x, y] pairs rather than letting an
        # IndexError abort drawing of every remaining shape.
        if pts.ndim != 2 or pts.shape[0] == 0 or pts.shape[1] < 2:
            continue
        x1, y1 = int(pts[:, 0].min()), int(pts[:, 1].min())
        x2, y2 = int(pts[:, 0].max()), int(pts[:, 1].max())
        label = shape["label"]
        color = color_for_label(label)
        if x2 <= x1 or y2 <= y1:  # zero-area box
            continue
        draw.rectangle([x1, y1, x2, y2], outline=color, width=line_width)
        left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
        text_w, text_h = right - left, bottom - top
        ty = max(0, y1 - text_h - 2 * pad)
        draw.rectangle(
            [x1, ty, x1 + text_w + 2 * pad, ty + text_h + 2 * pad], fill=(0, 0, 0)
        )
        # textbbox is measured from the drawing origin, and the glyph ink
        # usually starts at a non-zero offset (top > 0). Shift the origin back
        # by that offset; otherwise the text sits `top` px too low and spills
        # out of the bottom of its tag.
        draw.text((x1 + pad - left, ty + pad - top), label, font=font, fill=color)
    return img
# ---------------------------------------------------------------------------
# Prediction
# ---------------------------------------------------------------------------
def format_results(shapes, img_w, img_h):
    """Build the Markdown report shown in the results panel.

    The report always lists the image size and the number of shapes. When
    there are shapes, it adds a table with one row per shape: the label and a
    confidence string. The confidence is taken from the shape's
    ``"description"`` with any ``"confidence:"`` text removed and whitespace
    stripped. It is ``N/A`` when the description is missing or empty.
    """
    parts = [
        "## Detection Results\n\n",
        f"**Image Size:** {img_w} x {img_h} (W x H)\n\n",
        f"**Shapes Found:** {len(shapes)}\n\n",
    ]
    if shapes:
        parts += [
            "### Detected Objects\n",
            "| Label | Confidence |\n",
            "|-------|------------|\n",
        ]
        for shape in shapes:
            description = shape.get("description", "")
            if description:
                confidence = description.replace("confidence:", "").strip()
            else:
                confidence = "N/A"
            parts.append(f"| {shape['label']} | {confidence} |\n")
    return "".join(parts)
def call_api(image, confidence, iou, imgsz):
    """POST ``image`` to the hosted Ultralytics endpoint and return the JSON.

    The image is JPEG-encoded in memory and sent as the multipart field
    ``file``. ``conf``, ``iou`` and ``imgsz`` are sent as form fields, and
    ``API_KEY`` as a Bearer token.

    Args:
        image: PIL image in a JPEG-compatible mode (callers pass RGB).
        confidence: Confidence threshold.
        iou: IoU threshold for NMS.
        imgsz: Inference image size.

    Returns:
        The decoded JSON body of the response.

    Raises:
        RuntimeError: If ``API_URL`` or ``API_KEY`` is unset or empty.
        requests.exceptions.RequestException: Network failures (``timeout=60``
            is passed to requests) and ``HTTPError`` for 4xx/5xx responses.
    """
    # Fail fast with an actionable message. Otherwise requests would try to
    # POST to "None" or authenticate with "Bearer None".
    missing = [
        name for name, value in (("API_URL", API_URL), ("API_KEY", API_KEY)) if not value
    ]
    if missing:
        raise RuntimeError(
            f"Missing configuration: {', '.join(missing)}. "
            "Set the environment variable(s), e.g. as Space secrets."
        )
    params = {"conf": confidence, "iou": iou, "imgsz": imgsz}
    headers = {"Authorization": f"Bearer {API_KEY}"}
    with io.BytesIO() as img_bytes:
        image.save(img_bytes, format="JPEG")
        img_bytes.seek(0)
        files = {"file": ("image.jpg", img_bytes, "image/jpeg")}
        response = requests.post(
            API_URL, headers=headers, data=params, files=files, timeout=60
        )
    response.raise_for_status()
    return response.json()
def api_results_to_detections(api_result):
    """Convert a hosted-API response into the pipeline's detections dict.

    Only the first entry of ``api_result["images"]`` is read. Each item of its
    ``"results"`` list becomes one box dict with:

    * ``points``: the four corners in the order (x1, y1), (x2, y1), (x2, y2),
      (x1, y2), taken from the item's ``"box"`` as floats;
    * ``confidence``: float;
    * ``class_name``: from ``"name"``, defaulting to ``"unknown"``;
    * ``class_id``: int from ``"class"``.

    Missing numeric fields default to 0.

    Returns:
        ``{"boxes": [...]}``, with an empty list when ``api_result`` is not a
        dict or contains no images.
    """
    if not isinstance(api_result, dict):
        return {"boxes": []}
    images = api_result.get("images", [])
    if not images:
        return {"boxes": []}

    def _convert(det):
        coords = det.get("box", {})
        left, top, right, bottom = (
            float(coords.get(key, 0)) for key in ("x1", "y1", "x2", "y2")
        )
        return {
            "points": [[left, top], [right, top], [right, bottom], [left, bottom]],
            "confidence": float(det.get("confidence", 0)),
            "class_name": det.get("name", "unknown"),
            "class_id": int(det.get("class", 0)),
        }

    return {"boxes": [_convert(det) for det in images[0].get("results", [])]}
| def predict_image(image, confidence, iou, imgsz): | |
| if image is None: | |
| return None, None, "#### Please upload an image to begin detection" | |
| try: | |
| image = image.convert("RGB") | |
| api_result = call_api(image, float(confidence), float(iou), int(imgsz)) | |
| detections = api_results_to_detections(api_result) | |
| # Build + clean labelme JSON (rows span columns, columns span header->last row, dedupe) | |
| labelme_json = create_labelme_json( | |
| "image.png", detections, image.height, image.width) | |
| labelme_json = clean_labelme(labelme_json) | |
| shapes = labelme_json["shapes"] | |
| result_img = draw_shapes_on_image(image, shapes) | |
| report = format_results(shapes, image.width, image.height) | |
| json_path = os.path.join(os.getcwd(), "result.json") | |
| with open(json_path, "w", encoding="utf-8") as f: | |
| json.dump(labelme_json, f, indent=2) | |
| return result_img, json_path, report | |
| except requests.exceptions.Timeout: | |
| return None, None, "#### Error: Request timeout. Please try again." | |
| except requests.exceptions.ConnectionError: | |
| return None, None, "#### Error: Unable to connect to detection service." | |
| except requests.exceptions.HTTPError as e: | |
| return None, None, f"#### Error: API returned status {e.response.status_code}" | |
| except Exception as e: | |
| return None, None, f"#### Error: {str(e)}" | |
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# Dark variant of the monochrome theme. It is not passed to gr.Blocks below;
# the __main__ guard hands it to demo.launch().
dark_theme = gr.themes.Monochrome(primary_hue="slate", secondary_hue="slate").set(
    body_text_color="#e0e0e0",
    background_fill_primary="#0f0f0f",
    background_fill_secondary="#1a1a1a",
)

with gr.Blocks(title="Table Layout Detection") as demo:
    gr.Markdown("""
# Table Layout Detection
Detect table columns, rows and headers. Upload an image and adjust the
inference parameters. Boxes are auto-cleaned (rows span all columns, columns
span header→last row, duplicates removed) before being drawn.
""")
    with gr.Row():
        # Left column: image source, optional sample picker, inference knobs.
        with gr.Column(scale=1, min_width=400):
            gr.Markdown("### Input")
            image_input = gr.Image(
                label="Image", type="pil", sources=["upload"], interactive=True
            )
            test_images = get_test_images()
            if not test_images:
                gr.Markdown("No test images found. Add images to the 'images' folder.")
            else:
                test_image_radio = gr.Radio(
                    choices=[file_name for _, file_name in test_images],
                    label="Select test image",
                    info="Click to load",
                )
                # Map the selected file name back to its path and load it into
                # the image input.
                test_image_radio.change(
                    fn=lambda name: load_test_image(
                        next((path for path, file_name in test_images if file_name == name), None)
                    ),
                    inputs=[test_image_radio],
                    outputs=[image_input],
                )
            gr.Markdown("### Configuration")
            confidence_slider = gr.Slider(
                label="Confidence Threshold",
                minimum=0.0,
                maximum=1.0,
                value=0.2,
                step=0.01,
                info="Detection confidence level",
            )
            iou_slider = gr.Slider(
                label="IOU Threshold (NMS)",
                minimum=0.0,
                maximum=1.0,
                value=0.2,
                step=0.01,
                info="Intersection over union threshold",
            )
            imgsz_slider = gr.Slider(
                label="Image Size",
                minimum=320,
                maximum=2048,
                value=1280,
                step=32,
                info="Inference image resolution",
            )
            predict_btn = gr.Button("Detect Objects", variant="primary", size="lg")
        # Right column: annotated image, JSON download and Markdown report.
        with gr.Column(scale=1, min_width=400):
            gr.Markdown("### Results")
            image_output = gr.Image(label="Detections", type="pil", interactive=False)
            json_output = gr.File(label="Download labelme JSON")
            results_output = gr.Markdown(value="Detection results will appear here.")

    # Order matches predict_image's parameters and its returned tuple.
    inputs = [image_input, confidence_slider, iou_slider, imgsz_slider]
    outputs = [image_output, json_output, results_output]
    predict_btn.click(fn=predict_image, inputs=inputs, outputs=outputs)
    # Re-run detection whenever the image or any slider changes.
    for _trigger in inputs:
        _trigger.change(fn=predict_image, inputs=inputs, outputs=outputs)

if __name__ == "__main__":
    # Theme and CSS are supplied at launch; the CSS hides Gradio's footer.
    demo.launch(
        share=False,
        show_error=True,
        theme=dark_theme,
        css="footer {display: none !important;}",
    )