Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import requests | |
| import io | |
| import os | |
| from PIL import Image, ImageDraw, ImageFont | |
| from pathlib import Path | |
# Detection service endpoint and bearer token, taken from the environment;
# each is None when the corresponding variable is not set.
API_URL = os.environ.get("API_URL")
API_KEY = os.environ.get("API_KEY")
# Directory (relative to the working directory) scanned for sample images.
IMAGE_FOLDER = "images"
# Lower-case file extensions recognised as loadable sample images.
_TEST_IMAGE_SUFFIXES = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".gif"})


def get_test_images(folder=None):
    """List the sample images available in *folder*.

    Args:
        folder: Directory to scan. ``None`` (the default) means ``IMAGE_FOLDER``.

    Returns:
        A list of ``(path, filename)`` tuples, sorted by path, for every regular
        file whose extension (case-insensitive) is in ``_TEST_IMAGE_SUFFIXES``.
        An empty list when *folder* does not exist or is not a directory.
    """
    root = Path(IMAGE_FOLDER if folder is None else folder)
    if not root.is_dir():
        return []
    return [
        (str(path), path.name)
        for path in sorted(root.glob("*"))
        # is_file() skips sub-directories that merely look like images ("x.jpg/").
        if path.is_file() and path.suffix.lower() in _TEST_IMAGE_SUFFIXES
    ]
def load_test_image(image_path):
    """Open the sample image at *image_path* as an in-memory PIL image.

    The pixel data is read eagerly and the underlying file is closed before
    returning, so the returned image does not hold a file handle open.

    Args:
        image_path: Path of the image file; ``None`` or ``""`` is allowed.

    Returns:
        A ``PIL.Image.Image``, or ``None`` when *image_path* is empty or does
        not name an existing regular file.

    Raises:
        OSError (including ``PIL.UnidentifiedImageError``): if the file exists
        but cannot be decoded as an image.
    """
    if not image_path or not os.path.isfile(image_path):
        return None
    # Image.open() is lazy and keeps the file open; copy() forces the decode
    # and yields an image independent of the file, which the with-block closes.
    with Image.open(image_path) as img:
        return img.copy()
# Detector class id -> display name; only class 0 ("figure") is defined.
CLASS_NAMES = {0: "figure"}

# Detector class id -> RGB colour used for that class's box and label (orange).
CLASS_COLORS = {0: (255, 165, 0)}
def _load_font(font_size):
    """Return a TrueType font of *font_size* pixels, or a fallback font.

    Candidate files (Linux DejaVu/Liberation bold, macOS and Windows Arial,
    then ``arial.ttf`` relative to the working directory) are tried in order;
    the first that exists and loads is returned.

    If none load, Pillow's bundled default font is returned. It is returned at
    *font_size* when the installed Pillow supports ``load_default(size=...)``
    (Pillow 10.1+). Otherwise the small fixed-size legacy bitmap font is used.
    """
    font_paths = (
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
        "/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf",
        "/System/Library/Fonts/Arial.ttf",
        "C:\\Windows\\Fonts\\arial.ttf",
        "arial.ttf",
    )
    for path in font_paths:
        if not os.path.exists(path):
            continue
        try:
            return ImageFont.truetype(path, font_size)
        except (OSError, ValueError, ImportError):
            # Unreadable/corrupt font file, or Pillow built without FreeType:
            # try the next candidate instead of swallowing *every* exception.
            continue
    try:
        return ImageFont.load_default(size=font_size)
    except (TypeError, ImportError, OSError):
        # Pillow < 10.1 has no ``size`` argument, and without FreeType the
        # scalable default font is unavailable; use the legacy bitmap font.
        return ImageFont.load_default()
def _clip_box(box, img_width, img_height):
    """Return ``(x1, y1, x2, y2)`` as ints clipped to the image, or ``None``.

    *box* is a mapping with optional ``x1``/``y1``/``x2``/``y2`` keys (missing
    keys count as 0). Each coordinate is truncated to ``int`` and clamped to the
    last valid pixel index. ``None`` means the clipped box has no area. That
    covers missing boxes, degenerate boxes and boxes fully outside the image.
    """
    max_x = img_width - 1
    max_y = img_height - 1
    x1, y1, x2, y2 = (int(box.get(key, 0)) for key in ("x1", "y1", "x2", "y2"))
    x1, x2 = min(max(x1, 0), max_x), min(max(x2, 0), max_x)
    y1, y2 = min(max(y1, 0), max_y), min(max(y2, 0), max_y)
    if x2 <= x1 or y2 <= y1:
        return None
    return x1, y1, x2, y2


def draw_boxes_on_image(image, detections):
    """Return a copy of *image* with each detection's box and label drawn on it.

    Args:
        image: PIL image the detections refer to; it is not modified.
        detections: Iterable of dicts with optional ``confidence``, ``class``
            and ``box`` (``x1``/``y1``/``x2``/``y2`` coordinates) keys.

    Returns:
        *image* itself when *detections* is empty or ``None``. Otherwise, a new
        RGB image with one outlined box per detection, plus a
        ``"<Name> <confidence>%"`` label centred above it.
    """
    if not detections:
        return image
    # convert() always returns a new image, so the caller's image is untouched,
    # and forcing RGB guarantees the RGB colour tuples below are valid.
    annotated = image.convert("RGB")
    draw = ImageDraw.Draw(annotated)
    img_width, img_height = annotated.size
    min_dimension = min(img_width, img_height)
    # Scale stroke and text with the image, with floors so they stay visible.
    font_size = max(int(min_dimension * 0.02), 32)
    line_width = max(int(min_dimension * 0.008), 3)
    label_font = _load_font(font_size)
    bg_padding = 4

    for detection in detections:
        # Previously boxes touching the top/left edge (x1 == 0 or y1 == 0)
        # were silently skipped; clipping keeps them and drops only empty boxes.
        coords = _clip_box(detection.get("box") or {}, img_width, img_height)
        if coords is None:
            continue
        x1, y1, x2, y2 = coords
        class_id = detection.get("class", 0)
        color = CLASS_COLORS.get(class_id, (255, 165, 0))
        confidence = detection.get("confidence") or 0
        draw.rectangle([x1, y1, x2, y2], outline=color, width=line_width)

        name = CLASS_NAMES.get(class_id, "figure")
        label = f"{name[:1].upper()}{name[1:]} {confidence:.1%}"
        left, top, right, bottom = draw.textbbox((0, 0), label, font=label_font)
        text_width = right - left
        text_height = bottom - top

        # Centre the label above the box while keeping it inside the image.
        label_x = int((x1 + x2) / 2 - text_width / 2)
        label_x = max(2, min(label_x, img_width - text_width - 2))
        label_y = max(0, y1 - text_height - 5)

        draw.rectangle(
            [
                label_x - bg_padding,
                label_y - bg_padding,
                label_x + text_width + bg_padding,
                label_y + text_height + bg_padding,
            ],
            outline=color,
            fill=(0, 0, 0),
        )
        # textbbox() reports the ink's offset from the drawing origin; subtract
        # it so the glyphs land inside the background box computed above.
        draw.text((label_x - left, label_y - top), label, font=label_font, fill=color)
    return annotated
def _encode_jpeg(image):
    """Serialise *image* to an in-memory JPEG buffer rewound to the start.

    JPEG cannot store alpha or palette modes (e.g. RGBA PNGs, GIFs), so any
    non-RGB image is converted to RGB before encoding.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")
    buffer.seek(0)
    return buffer


def _first_image_detections(result):
    """Return ``result["images"][0]["results"]``, or ``[]`` if any part is missing."""
    if not isinstance(result, dict):
        return []
    images = result.get("images") or []
    if not images:
        return []
    return images[0].get("results") or []


def predict_image(image, confidence, iou, imgsz):
    """Send *image* to the detection API and return the annotated result.

    Args:
        image: PIL image to analyse, or ``None`` when nothing is uploaded.
        confidence: Confidence threshold, sent as the ``conf`` form field.
        iou: IoU threshold, sent as the ``iou`` form field.
        imgsz: Inference resolution, sent as the ``imgsz`` form field. It is
            cast to ``int`` because the UI slider may deliver floats (640.0).

    Returns:
        ``(annotated_image, markdown_report)`` on success, or
        ``(None, markdown_message)`` on missing input or any failure. Errors
        are reported in the message, never raised.
    """
    if image is None:
        return None, "#### Please upload an image to begin detection"
    if not API_URL:
        return None, "#### Error: API_URL is not configured. Please check API configuration."
    try:
        params = {
            "conf": confidence,
            "iou": iou,
            "imgsz": int(imgsz),
        }
        headers = {"Authorization": f"Bearer {API_KEY}"}
        files = {"file": ("image.jpg", _encode_jpeg(image), "image/jpeg")}
        response = requests.post(API_URL, headers=headers, data=params, files=files, timeout=30)
        response.raise_for_status()
        result = response.json()
        formatted_result = format_results(result)
        image_with_boxes = draw_boxes_on_image(image, _first_image_detections(result))
        return image_with_boxes, formatted_result
    # Timeout is checked first: ConnectTimeout is both a Timeout and a ConnectionError.
    except requests.exceptions.Timeout:
        return None, "#### Error: Request timeout. Please try again."
    except requests.exceptions.ConnectionError:
        return None, "#### Error: Unable to connect to detection service. Please check API configuration."
    except requests.exceptions.HTTPError as e:
        return None, f"#### Error: API returned status {e.response.status_code}"
    except Exception as e:
        # UI boundary: surface any other failure as a message instead of crashing.
        return None, f"#### Error: {str(e)}"
def format_results(result):
    """Render a detection API response as a Markdown report.

    Args:
        result: Parsed JSON response. Only ``result["images"][0]`` is read,
            using its optional keys:

            - ``shape``: the image size.
            - ``results``: detection dicts with ``name``, ``class`` and
              ``confidence``.
            - ``speed``: ``preprocess``/``inference``/``postprocess`` timings.

    Returns:
        A Markdown string. Non-dict responses are returned as ``str(result)``.
        A dict without images yields only the report heading.
    """
    if not isinstance(result, dict):
        return str(result)
    parts = ["## Detection Results\n\n"]
    images = result.get("images") or []
    if not images:
        return "".join(parts)

    img_data = images[0]
    shape = img_data.get("shape") or []
    detections = img_data.get("results") or []
    # A missing or short shape used to raise IndexError; report N/A instead.
    # NOTE(review): the label says (W x H); confirm the API's shape order, since
    # some detection APIs report [height, width].
    if len(shape) >= 2:
        parts.append(f"**Image Size:** {shape[0]} x {shape[1]} (W x H)\n")
    else:
        parts.append("**Image Size:** N/A\n")
    parts.append(f"**Detections Found:** {len(detections)}\n\n")

    speed = img_data.get("speed") or {}
    if speed:
        parts.append("\n### Performance Metrics\n")
        parts.append("| Metric | Time (ms) |\n")
        parts.append("|--------|----------|\n")
        for label, key in (
            ("Preprocess", "preprocess"),
            ("Inference", "inference"),
            ("Postprocess", "postprocess"),
        ):
            parts.append(f"| {label} | {speed.get(key, 'N/A')} |\n")

    if detections:
        parts.append("### Detected Objects\n")
        parts.append("| Label | Class | Confidence |\n")
        parts.append("|-------|-------|------------|\n")
        for det in detections:
            name = det.get("name", "Unknown")
            class_id = det.get("class", "N/A")
            # `or 0` also covers an explicit null, which would break the % format.
            conf = det.get("confidence") or 0
            parts.append(f"| {name} | {class_id} | {conf:.2%} |\n")
    return "".join(parts)
# Slate-tinted Monochrome base theme, then overridden with near-black surface
# fills and light-grey body text for a dark look.
dark_theme = gr.themes.Monochrome(secondary_hue="slate", primary_hue="slate")
dark_theme = dark_theme.set(
    background_fill_secondary="#1a1a1a",
    background_fill_primary="#0f0f0f",
    body_text_color="#e0e0e0",
)
# Client-side code for the fullscreen viewer. Clicking the detection output
# image opens it in #imageModal. The close button, a backdrop click or Escape
# dismisses it, and a two-finger pinch zooms 1x-4x on touch devices.
# It is passed through gr.Blocks(js=...) rather than a <script> tag inside
# gr.HTML: gr.HTML content is inserted as markup, where <script> tags do not run.
FULLSCREEN_MODAL_JS = """
() => {
    let attempts = 0;
    const init = () => {
        const modal = document.getElementById('imageModal');
        const modalImg = document.getElementById('modalImage');
        const closeBtn = document.querySelector('#imageModal .closeBtn');
        if (!modal || !modalImg || !closeBtn) {
            // The gr.HTML markup may not be mounted yet; retry for ~10 s.
            if (attempts++ < 40) {
                setTimeout(init, 250);
            }
            return;
        }
        let scale = 1;
        let lastDistance = 0;
        const closeModal = () => {
            modal.style.display = 'none';
            document.body.classList.remove('modal-open');
            scale = 1;
        };
        const touchDistance = (touches) => Math.hypot(
            touches[0].clientX - touches[1].clientX,
            touches[0].clientY - touches[1].clientY
        );
        const attachClickHandlers = () => {
            // Only the detections output (elem_id="detection_output"), not the input image.
            document.querySelectorAll('#detection_output img').forEach((img) => {
                if (!img.src || img.hasClickListener) {
                    return;
                }
                img.style.cursor = 'pointer';
                img.addEventListener('click', (e) => {
                    if (e.target.src && !e.target.src.includes('data:image/svg')) {
                        modalImg.src = e.target.src;
                        modal.style.display = 'block';
                        document.body.classList.add('modal-open');
                        scale = 1;
                        modalImg.style.transform = 'translate(-50%, -50%) scale(1)';
                    }
                });
                img.hasClickListener = true;
            });
        };
        // Re-scan periodically so images rendered after load also get a handler.
        setInterval(attachClickHandlers, 500);
        attachClickHandlers();
        modal.addEventListener('click', (e) => {
            if (e.target === modal) {
                closeModal();
            }
        });
        closeBtn.addEventListener('click', closeModal);
        document.addEventListener('keydown', (e) => {
            if (e.key === 'Escape' && modal.style.display === 'block') {
                closeModal();
            }
        });
        modalImg.addEventListener('touchstart', (e) => {
            if (e.touches.length === 2) {
                lastDistance = touchDistance(e.touches);
            }
        });
        modalImg.addEventListener('touchmove', (e) => {
            // lastDistance > 0 guards the division when no pinch start was recorded.
            if (e.touches.length === 2 && lastDistance > 0) {
                const distance = touchDistance(e.touches);
                scale = Math.max(1, Math.min(scale * (distance / lastDistance), 4));
                modalImg.style.transform = `translate(-50%, -50%) scale(${scale})`;
                lastDistance = distance;
            }
        });
        modalImg.addEventListener('touchend', () => {
            lastDistance = 0;
        });
    };
    init();
}
"""

with gr.Blocks(
    title="Figure Detection",
    theme=dark_theme,
    css="""
footer {display: none !important;}
.gradio-container {border-radius: 12px;}
.gr-card {border-radius: 12px;}
.block {border-radius: 12px;}
.form {border-radius: 12px;}
button {border-radius: 12px;}
.gr-button {border-radius: 12px;}
#imageModal {
    display: none;
    position: fixed;
    z-index: 10000;
    left: 0;
    top: 0;
    width: 100%;
    height: 100%;
    background-color: rgba(0, 0, 0, 0.9);
    animation: fadeIn 0.3s;
}
@keyframes fadeIn {
    from {opacity: 0;}
    to {opacity: 1;}
}
#modalImage {
    position: absolute;
    top: 50%;
    left: 50%;
    transform: translate(-50%, -50%);
    max-width: 95%;
    max-height: 95%;
    object-fit: contain;
    touch-action: pinch-zoom;
    cursor: zoom-out;
}
.modal-open {
    overflow: hidden;
}
.closeBtn {
    position: absolute;
    top: 20px;
    right: 30px;
    font-size: 40px;
    font-weight: bold;
    color: white;
    cursor: pointer;
    z-index: 10001;
}
.closeBtn:hover {
    color: #bbb;
}
""",
    js=FULLSCREEN_MODAL_JS,
) as demo:
    with gr.Column():
        gr.Markdown("""
# Figure Detection
Detect figures in your documents. Upload an image and adjust parameters to detect figures with custom inference settings.
""")
        with gr.Row():
            with gr.Column(scale=1, min_width=400):
                gr.Markdown("### Input")
                image_input = gr.Image(
                    label="Image",
                    type="pil",
                    sources=["upload"],
                    interactive=True,
                )
                test_images = get_test_images()
                if test_images:
                    # File name -> path; names are unique within one folder and
                    # the dict keeps get_test_images()'s sorted order.
                    test_image_paths = {name: path for path, name in test_images}
                    test_image_radio = gr.Radio(
                        choices=list(test_image_paths),
                        label="Select test image",
                        info="Click to load",
                    )
                    test_image_radio.change(
                        fn=lambda name: load_test_image(test_image_paths.get(name)),
                        inputs=[test_image_radio],
                        outputs=[image_input],
                    )
                else:
                    gr.Markdown("No test images found. Add images to the 'images' folder.")
                gr.Markdown("### Configuration")
                confidence_slider = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.25,
                    step=0.01,
                    info="Detection confidence level",
                )
                iou_slider = gr.Slider(
                    label="IOU Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.7,
                    step=0.01,
                    info="Intersection over union threshold",
                )
                imgsz_slider = gr.Slider(
                    label="Image Size",
                    minimum=320,
                    maximum=1280,
                    value=640,
                    step=32,
                    info="Inference image resolution",
                )
                predict_btn = gr.Button(
                    "Detect Objects",
                    variant="primary",
                    size="lg",
                    scale=1,
                )
            with gr.Column(scale=1, min_width=400):
                gr.Markdown("### Results")
                image_output = gr.Image(
                    label="Detections (Click to fullscreen)",
                    type="pil",
                    interactive=False,
                    scale=1,
                    # Stable hook for FULLSCREEN_MODAL_JS's click handlers.
                    elem_id="detection_output",
                )
                results_output = gr.Markdown(
                    value="Detection results will appear here.",
                    label="Detection Results",
                )
        # Markup only; its behaviour is wired up by FULLSCREEN_MODAL_JS.
        gr.HTML("""
<div id="imageModal">
<span class="closeBtn">×</span>
<img id="modalImage" src="" alt="Fullscreen Detection">
</div>
""")

    detection_inputs = [image_input, confidence_slider, iou_slider, imgsz_slider]
    detection_outputs = [image_output, results_output]
    # Run detection on demand and whenever the image or any parameter changes.
    # The registration order is unchanged: button, image, confidence, IoU, size.
    predict_btn.click(fn=predict_image, inputs=detection_inputs, outputs=detection_outputs)
    for trigger in detection_inputs:
        trigger.change(fn=predict_image, inputs=detection_inputs, outputs=detection_outputs)
if __name__ == "__main__":
    # Local serving only (no public share link); show_error=True reports
    # handler exceptions in the browser.
    demo.launch(show_error=True, share=False)