# walidhadri's picture
# Remove unsupported show_api arg
# 41d5ab0
#!/usr/bin/env python3
"""
Gradio demo for table detection with a YOLO model.
Designed for Hugging Face Spaces deployment.
"""
from __future__ import annotations

import logging
import os
from pathlib import Path
from typing import Optional, Tuple

import gradio as gr
import numpy as np
from huggingface_hub import snapshot_download
from PIL import Image
from ultralytics import YOLO
# --- UI text ----------------------------------------------------------------
APP_TITLE = "TableDetect-YOLO26"
APP_SUBTITLE = "Fast and robust table detection for document images."

# --- Hugging Face Hub locations (each overridable via the environment) ------
HF_MODEL_REPO = os.getenv("HF_MODEL_REPO", "walidhadri/table-detection-yolo26n")
HF_MODEL_FILENAME = os.getenv("HF_MODEL_FILENAME", "yolo26n-tablebank.pt")
# Example images default to living in the same repo as the model weights.
HF_EXAMPLES_REPO = os.getenv("HF_EXAMPLES_REPO", HF_MODEL_REPO)
HF_EXAMPLES_PATH = os.getenv("HF_EXAMPLES_PATH", "assets/example")

# --- Local weights ----------------------------------------------------------
DEFAULT_MODEL_PATH = Path("yolo26n.pt")
# YOLO_WEIGHTS, when set, replaces the default local weights location.
MODEL_PATH = Path(os.getenv("YOLO_WEIGHTS", str(DEFAULT_MODEL_PATH)))
def _resolve_weights_path() -> Path:
    """Locate the YOLO weights, downloading them from the Hub when absent.

    A path that already exists on disk (``MODEL_PATH``) always wins. Otherwise
    only ``HF_MODEL_FILENAME`` is fetched from ``HF_MODEL_REPO`` into the local
    ``models`` directory.

    Returns:
        Path to the weights that should be loaded.

    Raises:
        FileNotFoundError: If the download completed but the expected file is
            not present under the downloaded snapshot directory.
    """
    if MODEL_PATH.exists():
        return MODEL_PATH

    download_root = snapshot_download(
        repo_id=HF_MODEL_REPO,
        allow_patterns=[HF_MODEL_FILENAME],
        local_dir="models",
        local_dir_use_symlinks=False,
    )
    weights_file = Path(download_root, HF_MODEL_FILENAME)
    if weights_file.exists():
        return weights_file

    # The pattern matched nothing in the repo: point at the likely culprits.
    raise FileNotFoundError(
        f"Downloaded weights not found at {weights_file}. "
        "Check HF_MODEL_REPO/HF_MODEL_FILENAME."
    )
def _load_model() -> YOLO:
    """Build the YOLO detector from the resolved weights path."""
    return YOLO(str(_resolve_weights_path()))


# Loaded once at import time; ``predict`` reuses this single instance.
MODEL = _load_model()
def _plot_result(result) -> Image.Image:
    """Render one detection result as a PIL image.

    The array returned by ``result.plot()`` is treated as BGR, so its channel
    axis is reversed to obtain RGB before conversion.
    """
    annotated = result.plot()
    # Reverse the last axis (BGR -> RGB) before handing the array to PIL.
    return Image.fromarray(annotated[:, :, ::-1])
def predict(
    image: Optional[Image.Image], conf: float, iou: float, max_det: int
) -> Optional[Image.Image]:
    """Run table detection on one image and return the annotated picture.

    Args:
        image: Input picture from the UI, or ``None`` when nothing is loaded.
        conf: Confidence threshold passed to the model.
        iou: NMS IoU threshold passed to the model.
        max_det: Maximum number of detections to keep.

    Returns:
        The input image with detections drawn on it, or ``None`` when no
        input image was supplied.
    """
    if image is None:
        return None

    results = MODEL.predict(
        source=image,
        conf=conf,
        iou=iou,
        # Slider/API values may arrive as floats (e.g. 100.0); the model
        # configuration expects an integer for this setting.
        max_det=int(max_det),
        verbose=False,
    )
    return _plot_result(results[0])
def _ensure_example_assets() -> list[Path]:
    """Return the example images, downloading them from the Hub if missing.

    When the ``HF_EXAMPLES_PATH`` directory already exists, its ``*.jpg`` files
    are returned without any network access. Otherwise the matching files are
    fetched from ``HF_EXAMPLES_REPO`` into the current working directory.

    Returns:
        Sorted paths of the available ``*.jpg`` examples; an empty list when
        none exist or the download failed.
    """
    assets_dir = Path(HF_EXAMPLES_PATH)

    if not assets_dir.exists():
        try:
            snapshot_download(
                repo_id=HF_EXAMPLES_REPO,
                allow_patterns=[f"{HF_EXAMPLES_PATH}/*.jpg"],
                local_dir=".",
                local_dir_use_symlinks=False,
            )
        except Exception:
            # Examples are optional, so the app must still start without
            # them; record the failure instead of hiding it completely.
            logging.getLogger(__name__).warning(
                "Could not download example images from %s",
                HF_EXAMPLES_REPO,
                exc_info=True,
            )
            return []

    # The directory may still be absent if the repo had no matching files.
    return sorted(assets_dir.glob("*.jpg")) if assets_dir.exists() else []
def build_examples() -> list:
    """Return up to five example rows for the examples gallery.

    Each row is a one-element list holding an example image path as a string.
    """
    example_paths = map(str, _ensure_example_assets()[:5])
    return [[path] for path in example_paths]
# Custom stylesheet passed to gr.Blocks: hero banner, dark panels, page
# background/width overrides, and sizing for the examples gallery thumbnails.
CSS = """
.hero {
background: linear-gradient(135deg, #0b1f3a 0%, #0a5bc9 50%, #0d79ff 100%);
padding: 18px 22px;
border-radius: 14px;
color: #ffffff;
text-align: center;
margin-bottom: 14px;
box-shadow: 0 10px 24px rgba(0,0,0,0.25);
}
.hero h1 { font-size: 28px; margin: 0 0 6px 0; letter-spacing: 0.5px; }
.hero p { margin: 0; opacity: 0.9; }
.panel {
background: #1b1f27;
border-radius: 14px;
padding: 10px;
border: 1px solid #2b313d;
}
.controls label { font-weight: 600; }
.gradio-container {
background: #0f1115 !important;
color: #e6e6e6;
}
.gradio-container .container {
max-width: 1600px !important;
margin: 0 auto !important;
width: 100% !important;
}
#examples .gallery-item {
width: 120px !important;
height: 170px !important;
}
#examples .gallery-item img {
object-fit: contain !important;
}
"""
# Assemble the UI: a hero banner on top, then a two-column row with the input
# and controls on one side and the annotated prediction on the other.
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    # Banner showing the app title and subtitle, styled by the ".hero" rules.
    gr.HTML(f"""
<div class="hero">
<h1>{APP_TITLE}</h1>
<p>{APP_SUBTITLE}</p>
</div>
""")
    with gr.Row():
        # First column: input image, action buttons, settings, examples.
        with gr.Column(scale=1):
            with gr.Group(elem_classes="panel"):
                input_image = gr.Image(type="pil", label="Input Image", height=560)
            with gr.Row():
                clear_btn = gr.Button("Clear")
                detect_btn = gr.Button("Detect", variant="primary")
            # Sliders whose values are forwarded to predict() as-is.
            with gr.Accordion("Detection Settings", open=True):
                conf = gr.Slider(0.05, 0.9, value=0.25, step=0.01, label="Confidence Threshold")
                iou = gr.Slider(0.1, 0.9, value=0.5, step=0.01, label="NMS IoU Threshold")
                max_det = gr.Slider(1, 300, value=100, step=1, label="Max Detections")
            # The gallery is only rendered when example images are available.
            examples = build_examples()
            if examples:
                gr.Examples(
                    examples=examples,
                    inputs=[input_image],
                    label="Examples",
                    elem_id="examples",
                    examples_per_page=12,
                )
        # Second column: the image returned by predict().
        with gr.Column(scale=1):
            with gr.Group(elem_classes="panel"):
                output_image = gr.Image(type="pil", label="Predict Result", height=560)
    # "Detect" runs inference with the current image and slider values.
    detect_btn.click(
        fn=predict,
        inputs=[input_image, conf, iou, max_det],
        outputs=[output_image],
    )
    # "Clear" resets both the input and the output image to empty.
    clear_btn.click(fn=lambda: (None, None), inputs=[], outputs=[input_image, output_image])
if __name__ == "__main__":
    # Turn on the request queue, then serve with server_name "0.0.0.0".
    queued_app = demo.queue()
    queued_app.launch(server_name="0.0.0.0")