Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| """ | |
| Gradio demo for table detection with a YOLO model. | |
| Designed for Hugging Face Spaces deployment. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| from pathlib import Path | |
| from typing import Optional, Tuple | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| from huggingface_hub import snapshot_download | |
| from ultralytics import YOLO | |
# --- Configuration: every value below can be overridden via environment variables. ---
APP_TITLE = "TableDetect-YOLO26"
APP_SUBTITLE = "Fast and robust table detection for document images."
# Hugging Face Hub repo and filename holding the model weights.
HF_MODEL_REPO = os.environ.get("HF_MODEL_REPO", "walidhadri/table-detection-yolo26n")
HF_MODEL_FILENAME = os.environ.get("HF_MODEL_FILENAME", "yolo26n-tablebank.pt")
# Repo/path for the example images shown in the UI (defaults to the model repo).
HF_EXAMPLES_REPO = os.environ.get("HF_EXAMPLES_REPO", HF_MODEL_REPO)
HF_EXAMPLES_PATH = os.environ.get("HF_EXAMPLES_PATH", "assets/example")
# Local weights path that is checked before any download from the Hub.
DEFAULT_MODEL_PATH = Path("yolo26n.pt")
MODEL_PATH = Path(os.environ.get("YOLO_WEIGHTS", DEFAULT_MODEL_PATH))
def _resolve_weights_path() -> Path:
    """Return a local filesystem path to the YOLO weights file.

    Prefers a pre-existing file at ``MODEL_PATH`` (e.g. baked into the
    image or mounted); otherwise downloads only the weights file from
    ``HF_MODEL_REPO`` into a local ``models/`` directory.

    Returns:
        Path to an existing ``.pt`` weights file.

    Raises:
        FileNotFoundError: if the snapshot download completed but the
            expected filename is missing, i.e. HF_MODEL_REPO /
            HF_MODEL_FILENAME are misconfigured.
    """
    if MODEL_PATH.exists():
        return MODEL_PATH
    cache_dir = Path("models")
    # snapshot_download returns the directory containing the downloaded
    # files. `local_dir_use_symlinks` is deprecated and ignored by recent
    # huggingface_hub releases (files are always real copies with
    # local_dir), so it is intentionally omitted here.
    snapshot_path = Path(
        snapshot_download(
            repo_id=HF_MODEL_REPO,
            allow_patterns=[HF_MODEL_FILENAME],
            local_dir=str(cache_dir),
        )
    )
    downloaded = snapshot_path / HF_MODEL_FILENAME
    if not downloaded.exists():
        raise FileNotFoundError(
            f"Downloaded weights not found at {downloaded}. "
            "Check HF_MODEL_REPO/HF_MODEL_FILENAME."
        )
    return downloaded
def _load_model() -> YOLO:
    """Instantiate the YOLO detector from the resolved weights file."""
    return YOLO(str(_resolve_weights_path()))

# Loaded once at import time so the cold-start cost is paid before the
# first request is served.
MODEL = _load_model()
def _plot_result(result) -> Image.Image:
    """Render one Ultralytics result as an RGB PIL image.

    ``Results.plot()`` returns an annotated ndarray in BGR channel order
    (OpenCV convention); the channel axis is reversed to obtain RGB
    before wrapping it in a PIL image.
    """
    annotated_bgr = result.plot()
    return Image.fromarray(annotated_bgr[..., ::-1])
def predict(image: Image.Image | None, conf: float, iou: float, max_det: int) -> Image.Image | None:
    """Run table detection on *image* and return the annotated image.

    Args:
        image: Input document image; ``None`` when the image widget is empty.
        conf: Minimum confidence threshold for kept detections.
        iou: IoU threshold used by non-max suppression.
        max_det: Maximum number of boxes to keep.

    Returns:
        The annotated PIL image, or ``None`` when no input was provided.
        (The original annotation claimed ``Image.Image`` despite the
        explicit ``None`` early return — fixed to an optional return.)
    """
    if image is None:
        return None
    results = MODEL.predict(
        source=image,
        conf=conf,
        iou=iou,
        max_det=max_det,
        verbose=False,
    )
    # predict() on a single image yields a one-element list of Results.
    return _plot_result(results[0])
def _ensure_example_assets() -> list[Path]:
    """Return sorted example ``*.jpg`` paths, downloading them if needed.

    If ``HF_EXAMPLES_PATH`` already exists locally it is used as-is;
    otherwise the matching JPEGs are fetched from ``HF_EXAMPLES_REPO``
    into the working directory. Download failures are swallowed
    deliberately: examples are a nice-to-have and must never break app
    startup.

    Returns:
        Sorted list of example image paths; empty when none are available.
    """
    assets_dir = Path(HF_EXAMPLES_PATH)
    if assets_dir.exists():
        return sorted(assets_dir.glob("*.jpg"))
    try:
        # `local_dir_use_symlinks` is deprecated and ignored by recent
        # huggingface_hub releases, so it is intentionally omitted.
        snapshot_download(
            repo_id=HF_EXAMPLES_REPO,
            allow_patterns=[f"{HF_EXAMPLES_PATH}/*.jpg"],
            local_dir=".",
        )
    except Exception:
        # Best effort only: missing examples are not fatal.
        return []
    if assets_dir.exists():
        return sorted(assets_dir.glob("*.jpg"))
    return []
def build_examples() -> list:
    """Build the ``gr.Examples`` payload: up to five single-image rows."""
    paths = _ensure_example_assets()
    return [[str(path)] for path in paths[:5]]
# Custom stylesheet: dark theme, gradient hero banner, wide container, and
# fixed-size example thumbnails. Injected via gr.Blocks(css=CSS).
CSS = """
.hero {
background: linear-gradient(135deg, #0b1f3a 0%, #0a5bc9 50%, #0d79ff 100%);
padding: 18px 22px;
border-radius: 14px;
color: #ffffff;
text-align: center;
margin-bottom: 14px;
box-shadow: 0 10px 24px rgba(0,0,0,0.25);
}
.hero h1 { font-size: 28px; margin: 0 0 6px 0; letter-spacing: 0.5px; }
.hero p { margin: 0; opacity: 0.9; }
.panel {
background: #1b1f27;
border-radius: 14px;
padding: 10px;
border: 1px solid #2b313d;
}
.controls label { font-weight: 600; }
.gradio-container {
background: #0f1115 !important;
color: #e6e6e6;
}
.gradio-container .container {
max-width: 1600px !important;
margin: 0 auto !important;
width: 100% !important;
}
#examples .gallery-item {
width: 120px !important;
height: 170px !important;
}
#examples .gallery-item img {
object-fit: contain !important;
}
"""
# UI definition: two-column layout (input + controls on the left, annotated
# output on the right) under a hero banner. NOTE(review): the original
# source lost its indentation; the nesting below is reconstructed from the
# context managers — confirm against the rendered layout.
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    # Hero banner showing app title and subtitle.
    gr.HTML(f"""
    <div class="hero">
    <h1>{APP_TITLE}</h1>
    <p>{APP_SUBTITLE}</p>
    </div>
    """)
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group(elem_classes="panel"):
                input_image = gr.Image(type="pil", label="Input Image", height=560)
                with gr.Row():
                    clear_btn = gr.Button("Clear")
                    detect_btn = gr.Button("Detect", variant="primary")
                # Inference knobs forwarded verbatim to MODEL.predict().
                with gr.Accordion("Detection Settings", open=True):
                    conf = gr.Slider(0.05, 0.9, value=0.25, step=0.01, label="Confidence Threshold")
                    iou = gr.Slider(0.1, 0.9, value=0.5, step=0.01, label="NMS IoU Threshold")
                    max_det = gr.Slider(1, 300, value=100, step=1, label="Max Detections")
            # Example gallery is optional: shown only when assets resolved.
            examples = build_examples()
            if examples:
                gr.Examples(
                    examples=examples,
                    inputs=[input_image],
                    label="Examples",
                    elem_id="examples",
                    examples_per_page=12,
                )
        with gr.Column(scale=1):
            with gr.Group(elem_classes="panel"):
                output_image = gr.Image(type="pil", label="Predict Result", height=560)
    # Wiring: Detect runs inference; Clear empties both image widgets.
    detect_btn.click(
        fn=predict,
        inputs=[input_image, conf, iou, max_det],
        outputs=[output_image],
    )
    clear_btn.click(fn=lambda: (None, None), inputs=[], outputs=[input_image, output_image])
if __name__ == "__main__":
    # Queueing enables request buffering under load; binding to 0.0.0.0
    # lets the Spaces container route external traffic to the app.
    demo.queue().launch(server_name="0.0.0.0")