import io
import os
from pathlib import Path

import gradio as gr
import requests
from PIL import Image, ImageDraw, ImageFont

API_URL = os.getenv("API_URL")
API_KEY = os.getenv("API_KEY")
IMAGE_FOLDER = "images"

# File extensions offered in the "Select test image" picker.
TEST_IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp", ".gif"}


def get_test_images():
    """Return ``(path, filename)`` pairs for the sample images in IMAGE_FOLDER.

    Only regular files whose (case-insensitive) extension is listed in
    TEST_IMAGE_SUFFIXES are returned, sorted by path. Returns an empty list
    when the folder does not exist.
    """
    folder = Path(IMAGE_FOLDER)
    if not folder.is_dir():
        return []
    return [
        (str(file), file.name)
        for file in sorted(folder.glob("*"))
        if file.is_file() and file.suffix.lower() in TEST_IMAGE_SUFFIXES
    ]


def load_test_image(image_path):
    """Load the image at *image_path* fully into memory.

    Returns None when *image_path* is empty or does not exist. The pixel data
    is copied inside a ``with`` block so the file handle is closed right away
    instead of being held open by Pillow's lazy loader.
    """
    if not image_path or not os.path.exists(image_path):
        return None
    with Image.open(image_path) as img:
        return img.copy()


CLASS_NAMES = {0: "figure"}
CLASS_COLORS = {
    0: (255, 165, 0),
}


def _load_font(font_size):
    """Return a bold TrueType font of *font_size*, or Pillow's default font.

    Tries common Linux / macOS / Windows font locations in order and falls
    back to ``ImageFont.load_default()`` when none can be loaded.
    """
    font_paths = [
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
        "/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf",
        "/System/Library/Fonts/Arial.ttf",
        "C:\\Windows\\Fonts\\arial.ttf",
        "arial.ttf",
    ]
    for path in font_paths:
        if not os.path.exists(path):
            continue
        try:
            return ImageFont.truetype(path, font_size)
        except (OSError, ValueError):
            # Unreadable or unsupported font file: try the next candidate.
            continue
    return ImageFont.load_default()


def draw_boxes_on_image(image, detections):
    """Return a copy of *image* with a labelled box drawn for each detection.

    Args:
        image: PIL image that the detection coordinates refer to.
        detections: list of dicts as returned by the API. Each is read for
            ``"confidence"``, ``"class"`` and ``"box"`` (a dict with
            ``x1``/``y1``/``x2``/``y2``).

    Returns:
        *image* itself when *detections* is empty, otherwise an annotated
        copy. Boxes with a negative corner or zero/negative area are skipped.
    """
    if not detections:
        return image

    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    img_width, img_height = annotated.size

    # Scale stroke width and label size with the image, with readable minimums.
    min_dimension = min(img_width, img_height)
    font_size = max(int(min_dimension * 0.02), 32)
    line_width = max(int(min_dimension * 0.008), 3)
    label_font = _load_font(font_size)
    bg_padding = 4

    for detection in detections:
        box = detection.get("box") or {}
        x1 = int(box.get("x1", 0))
        y1 = int(box.get("y1", 0))
        x2 = int(box.get("x2", 0))
        y2 = int(box.get("y2", 0))
        # A corner on the top/left border (coordinate 0) is legitimate; only
        # skip negative or degenerate boxes.
        if x1 < 0 or y1 < 0 or x2 <= x1 or y2 <= y1:
            continue

        class_id = detection.get("class", 0)
        confidence = detection.get("confidence") or 0
        color = CLASS_COLORS.get(class_id, (255, 165, 0))
        draw.rectangle([x1, y1, x2, y2], outline=color, width=line_width)

        class_name = CLASS_NAMES.get(class_id, "figure").capitalize()
        label = f"{class_name} {confidence:.1%}"
        bbox = draw.textbbox((0, 0), label, font=label_font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]

        # Centre the label above the box while keeping it inside the image.
        label_x = int((x1 + x2) / 2 - text_width / 2)
        label_y = max(0, y1 - text_height - 5)
        if label_x < 0:
            label_x = 2
        if label_x + text_width > img_width:
            label_x = img_width - text_width - 2

        bg_box = [
            label_x - bg_padding,
            label_y - bg_padding,
            label_x + text_width + bg_padding,
            label_y + text_height + bg_padding,
        ]
        draw.rectangle(bg_box, outline=color, fill=(0, 0, 0))
        draw.text((label_x, label_y), label, font=label_font, fill=color)

    return annotated


def _extract_detections(result):
    """Return the detection list of the first image in an API *result*, or []."""
    if not isinstance(result, dict):
        return []
    images = result.get("images") or []
    if not images:
        return []
    return images[0].get("results") or []


def predict_image(image, confidence, iou, imgsz):
    """Run figure detection on *image* through the remote API.

    Args:
        image: PIL image from the input component, or None.
        confidence: confidence threshold, sent as the ``conf`` form field.
        iou: IoU threshold, sent as the ``iou`` form field.
        imgsz: inference resolution, sent as the ``imgsz`` form field.

    Returns:
        ``(annotated_image, markdown)``. On any failure the image is None and
        the markdown is a user-facing error message. Exceptions are not
        propagated because this function is a Gradio event handler.
    """
    if image is None:
        return None, "#### Please upload an image to begin detection"
    if not API_URL:
        return None, (
            "#### Error: API_URL is not configured. "
            "Please set the API_URL environment variable."
        )

    try:
        # JPEG cannot store palette (GIF) or alpha (RGBA PNG) images. Normalise
        # to RGB so save() succeeds and the coloured boxes render correctly.
        if image.mode != "RGB":
            image = image.convert("RGB")

        img_bytes = io.BytesIO()
        image.save(img_bytes, format="JPEG")
        img_bytes.seek(0)

        params = {"conf": confidence, "iou": iou, "imgsz": imgsz}
        # Do not send the literal header "Bearer None" when no key is configured.
        headers = {"Authorization": f"Bearer {API_KEY}"} if API_KEY else {}
        files = {"file": ("image.jpg", img_bytes, "image/jpeg")}

        response = requests.post(
            API_URL, headers=headers, data=params, files=files, timeout=30
        )
        response.raise_for_status()
        result = response.json()

        formatted_result = format_results(result)
        image_with_boxes = draw_boxes_on_image(image, _extract_detections(result))
        return image_with_boxes, formatted_result

    except requests.exceptions.Timeout:
        return None, "#### Error: Request timeout. Please try again."
    except requests.exceptions.ConnectionError:
        return None, "#### Error: Unable to connect to detection service. Please check API configuration."
    except requests.exceptions.HTTPError as e:
        return None, f"#### Error: API returned status {e.response.status_code}"
    except Exception as e:
        # UI boundary: surface any other failure as a message, not a traceback.
        return None, f"#### Error: {str(e)}"


def format_results(result):
    """Render the API JSON *result* as a Markdown report.

    Non-dict results are returned as ``str(result)``. For a dict, only the
    first entry of ``result["images"]`` is reported. Missing or null fields
    are tolerated instead of raising.
    """
    if not isinstance(result, dict):
        return str(result)

    parts = ["## Detection Results\n\n"]
    images = result.get("images") or []
    if images:
        img_data = images[0]
        shape = img_data.get("shape") or []
        detections = img_data.get("results") or []

        # NOTE(review): labelled (W x H) as before; confirm the API's "shape"
        # order, since many detectors report (height, width).
        if len(shape) >= 2:
            parts.append(f"**Image Size:** {shape[0]} x {shape[1]} (W x H)\n\n")
        parts.append(f"**Detections Found:** {len(detections)}\n\n")

        speed = img_data.get("speed") or {}
        if speed:
            parts.append("\n### Performance Metrics\n")
            parts.append("| Metric | Time (ms) |\n")
            parts.append("|--------|----------|\n")
            for metric, key in (
                ("Preprocess", "preprocess"),
                ("Inference", "inference"),
                ("Postprocess", "postprocess"),
            ):
                parts.append(f"| {metric} | {speed.get(key, 'N/A')} |\n")
            # Blank line so the table is cleanly terminated before the next heading.
            parts.append("\n")

        if detections:
            parts.append("### Detected Objects\n")
            parts.append("| Label | Class | Confidence |\n")
            parts.append("|-------|-------|------------|\n")
            for det in detections:
                name = det.get("name", "Unknown")
                class_id = det.get("class", "N/A")
                conf = det.get("confidence") or 0
                parts.append(f"| {name} | {class_id} | {conf:.2%} |\n")

    return "".join(parts)


dark_theme = gr.themes.Monochrome(
    primary_hue="slate",
    secondary_hue="slate",
).set(
    body_text_color="#e0e0e0",
    background_fill_primary="#0f0f0f",
    background_fill_secondary="#1a1a1a",
)

APP_CSS = """
footer {display: none !important;}
.gradio-container {border-radius: 12px;}
.gr-card {border-radius: 12px;}
.block {border-radius: 12px;}
.form {border-radius: 12px;}
button {border-radius: 12px;}
.gr-button {border-radius: 12px;}
#imageModal {
    display: none;
    position: fixed;
    z-index: 10000;
    left: 0;
    top: 0;
    width: 100%;
    height: 100%;
    background-color: rgba(0, 0, 0, 0.9);
    animation: fadeIn 0.3s;
}
@keyframes fadeIn {
    from {opacity: 0;}
    to {opacity: 1;}
}
#modalImage {
    position: absolute;
    top: 50%;
    left: 50%;
    transform: translate(-50%, -50%);
    max-width: 95%;
    max-height: 95%;
    object-fit: contain;
    touch-action: pinch-zoom;
    cursor: zoom-out;
}
.modal-open {
    overflow: hidden;
}
.closeBtn {
    position: absolute;
    top: 20px;
    right: 30px;
    font-size: 40px;
    font-weight: bold;
    color: white;
    cursor: pointer;
    z-index: 10001;
}
.closeBtn:hover {
    color: #bbb;
}
"""

with gr.Blocks(title="Figure Detection", theme=dark_theme, css=APP_CSS) as demo:
    with gr.Column():
        gr.Markdown(
            "# Figure Detection\n\n"
            "Detect figures in your documents. Upload an image and adjust "
            "parameters to detect figures with custom inference settings."
        )

        with gr.Row():
            with gr.Column(scale=1, min_width=400):
                gr.Markdown("### Input")
                image_input = gr.Image(
                    label="Image",
                    type="pil",
                    sources=["upload"],
                    interactive=True,
                )

                test_images = get_test_images()
                if test_images:
                    # Filenames are unique within IMAGE_FOLDER, so they can key the lookup.
                    test_image_paths = {name: path for path, name in test_images}
                    test_image_radio = gr.Radio(
                        choices=[name for _, name in test_images],
                        label="Select test image",
                        info="Click to load",
                    )
                    test_image_radio.change(
                        fn=lambda name: load_test_image(test_image_paths.get(name)),
                        inputs=[test_image_radio],
                        outputs=[image_input],
                    )
                else:
                    gr.Markdown("No test images found. Add images to the 'images' folder.")

                gr.Markdown("### Configuration")
                confidence_slider = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.25,
                    step=0.01,
                    info="Detection confidence level",
                )
                iou_slider = gr.Slider(
                    label="IOU Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.7,
                    step=0.01,
                    info="Intersection over union threshold",
                )
                imgsz_slider = gr.Slider(
                    label="Image Size",
                    minimum=320,
                    maximum=1280,
                    value=640,
                    step=32,
                    info="Inference image resolution",
                )
                predict_btn = gr.Button(
                    "Detect Objects",
                    variant="primary",
                    size="lg",
                    scale=1,
                )

            with gr.Column(scale=1, min_width=400):
                gr.Markdown("### Results")
                image_output = gr.Image(
                    label="Detections (Click to fullscreen)",
                    type="pil",
                    interactive=False,
                    scale=1,
                )
                results_output = gr.Markdown(
                    value="Detection results will appear here.",
                    label="Detection Results",
                )

        # NOTE(review): APP_CSS styles #imageModal / #modalImage / .closeBtn,
        # but this HTML only contains the text below. The modal markup (and any
        # script that opens it) appears to be missing; confirm against history.
        gr.HTML("""
× Fullscreen Detection
""")

    detection_inputs = [image_input, confidence_slider, iou_slider, imgsz_slider]
    detection_outputs = [image_output, results_output]

    # Run detection on button click and whenever the image or any parameter changes.
    # NOTE(review): slider .change fires while dragging, so each drag may issue
    # several API calls; consider a release-style trigger if that is a concern.
    predict_btn.click(fn=predict_image, inputs=detection_inputs, outputs=detection_outputs)
    for trigger in (image_input, confidence_slider, iou_slider, imgsz_slider):
        trigger.change(fn=predict_image, inputs=detection_inputs, outputs=detection_outputs)

if __name__ == "__main__":
    demo.launch(share=False, show_error=True)