# graph_detection / app.py
# vichetkao's picture
# Update app.py
# 5e54b1d verified
# Raw
# History Blame Contribute Delete
# 16.2 kB
import gradio as gr
import requests
import io
import os
from PIL import Image, ImageDraw, ImageFont
from pathlib import Path
# Endpoint and credential for the remote detection service are read from the
# environment; each one is None when its variable is not set.
API_URL = os.environ.get("API_URL")
API_KEY = os.environ.get("API_KEY")

# Relative folder that is scanned for bundled sample images.
IMAGE_FOLDER = "images"
def get_test_images(folder=None):
    """List the sample images found directly inside the test-image folder.

    Args:
        folder: Directory to scan. When omitted, the module-level
            ``IMAGE_FOLDER`` is used.

    Returns:
        A list of ``(path, file_name)`` string tuples sorted by path. The
        list is empty when the folder does not exist or holds no images.
    """
    image_suffixes = {".jpg", ".jpeg", ".png", ".bmp", ".gif"}
    root = Path(IMAGE_FOLDER if folder is None else folder)
    if not root.is_dir():
        return []
    return [
        (str(entry), entry.name)
        for entry in sorted(root.glob("*"))
        # is_file() keeps out sub-directories that merely carry an
        # image-like suffix (e.g. a folder called "scans.png").
        if entry.is_file() and entry.suffix.lower() in image_suffixes
    ]
def load_test_image(image_path):
    """Load a sample image from disk.

    Args:
        image_path: Path of the image file, or a falsy value.

    Returns:
        A fully loaded PIL image, or ``None`` when ``image_path`` is empty or
        does not point at a regular file.
    """
    # isfile() rather than exists(): a directory path must not reach
    # Image.open(), which would raise instead of yielding "no image".
    if not image_path or not os.path.isfile(image_path):
        return None
    # Image.open() is lazy and keeps the file handle open; copy the pixel
    # data while the file is open so the handle is released on exit.
    with Image.open(image_path) as source:
        loaded = source.copy()
        # copy() drops the source format; keep it so consumers that inspect
        # .format still see the original file type.
        loaded.format = source.format
    return loaded
# Display name for each detector class id.
CLASS_NAMES = {
    0: "figure",
}

# Outline/label colour per class id, as an RGB tuple (orange for class 0).
CLASS_COLORS = {0: (255, 165, 0)}
def _load_font(font_size):
    """Return a bold/sans font at ``font_size``, falling back to Pillow's default.

    Args:
        font_size: Requested font size, passed to ``ImageFont.truetype``.

    Returns:
        The first candidate font that Pillow can open, otherwise Pillow's
        built-in default font.
    """
    font_paths = (
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
        "/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf",
        "/System/Library/Fonts/Arial.ttf",
        "C:\\Windows\\Fonts\\arial.ttf",
        "arial.ttf",
    )
    for path in font_paths:
        # No os.path.exists() pre-check: Pillow resolves bare names such as
        # "arial.ttf" against the system font directories itself, and a
        # missing file simply raises OSError.
        try:
            return ImageFont.truetype(path, font_size)
        except (OSError, ImportError):
            # OSError: file missing/unreadable. ImportError: Pillow was built
            # without FreeType. Either way, try the next candidate.
            continue
    try:
        # Newer Pillow releases accept a size for the default font; without
        # it the fallback ignores font_size entirely.
        return ImageFont.load_default(size=font_size)
    except (TypeError, OSError, ImportError):
        # Older Pillow (no ``size`` argument) or no FreeType support.
        return ImageFont.load_default()
def draw_boxes_on_image(image, detections):
    """Draw a labelled bounding box for every detection on a copy of ``image``.

    Args:
        image: PIL image to annotate. It is never modified in place.
        detections: Iterable of dicts; the keys read here are ``"box"``
            (a dict with ``x1``/``y1``/``x2``/``y2``), ``"class"`` and
            ``"confidence"``.

    Returns:
        ``image`` itself when ``detections`` is empty or ``None``, otherwise
        an annotated copy.
    """
    if not detections:
        return image

    # The colours used below are RGB tuples, which Pillow cannot draw onto
    # single-band images, so anything that is not RGB(A) is converted.
    if image.mode in ("RGB", "RGBA"):
        canvas = image.copy()
    else:
        canvas = image.convert("RGB")
    draw = ImageDraw.Draw(canvas)
    img_width, img_height = canvas.size

    # Scale label text and outline with the image, with readable minimums.
    min_dimension = min(img_width, img_height)
    font_size = max(int(min_dimension * 0.02), 32)
    line_width = max(int(min_dimension * 0.008), 3)
    label_font = _load_font(font_size)
    bg_padding = 4
    edge_margin = 2

    for detection in detections:
        confidence = detection.get("confidence") or 0
        class_id = detection.get("class", 0)
        box = detection.get("box") or {}
        color = CLASS_COLORS.get(class_id, (255, 165, 0))

        # Clamp to the image instead of rejecting coordinates at 0: a box
        # flush against the left or top edge is a perfectly valid detection.
        x1 = max(int(box.get("x1", 0)), 0)
        y1 = max(int(box.get("y1", 0)), 0)
        x2 = min(int(box.get("x2", 0)), img_width - 1)
        y2 = min(int(box.get("y2", 0)), img_height - 1)
        if x2 <= x1 or y2 <= y1:
            # Missing or degenerate box: nothing sensible to draw.
            continue

        draw.rectangle([x1, y1, x2, y2], outline=color, width=line_width)

        class_label = CLASS_NAMES.get(class_id, "figure").capitalize()
        label = f"{class_label} {confidence:.1%}"
        bbox = draw.textbbox((0, 0), label, font=label_font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]

        # Centre the label over the box, then keep it inside the image.
        label_x = int((x1 + x2) / 2 - text_width / 2)
        label_x = min(label_x, img_width - text_width - edge_margin)
        label_x = max(label_x, edge_margin)
        label_y = max(0, y1 - text_height - 5)

        bg_box = [
            label_x - bg_padding,
            label_y - bg_padding,
            label_x + text_width + bg_padding,
            label_y + text_height + bg_padding,
        ]
        draw.rectangle(bg_box, outline=color, fill=(0, 0, 0))
        # textbbox() reports where the ink starts relative to the anchor;
        # subtract that offset so the glyphs sit inside the background
        # rectangle instead of spilling out of its bottom edge.
        draw.text(
            (label_x - bbox[0], label_y - bbox[1]),
            label,
            font=label_font,
            fill=color,
        )
    return canvas
def predict_image(image, confidence, iou, imgsz):
    """Send ``image`` to the detection service and render the response.

    Args:
        image: PIL image to analyse, or ``None`` when nothing is loaded.
        confidence: Value sent to the service as ``conf``.
        iou: Value sent to the service as ``iou``.
        imgsz: Value sent to the service as ``imgsz``.

    Returns:
        A ``(annotated_image, markdown)`` tuple. On any failure the image is
        ``None`` and the Markdown string carries the error message; no
        exception escapes to the caller.
    """
    if image is None:
        return None, "#### Please upload an image to begin detection"
    if not API_URL:
        # Fail with an actionable message instead of a low-level URL error.
        return None, (
            "#### Error: Detection service is not configured. "
            "Please set the API_URL environment variable."
        )
    try:
        # JPEG cannot store alpha or palette modes, so normalise first.
        if image.mode != "RGB":
            image = image.convert("RGB")
        img_bytes = io.BytesIO()
        image.save(img_bytes, format='JPEG')
        img_bytes.seek(0)

        params = {
            "conf": confidence,
            "iou": iou,
            "imgsz": imgsz
        }
        # Only send credentials when a key is configured, rather than the
        # literal string "Bearer None".
        headers = {"Authorization": f"Bearer {API_KEY}"} if API_KEY else {}
        files = {"file": ("image.jpg", img_bytes, "image/jpeg")}
        response = requests.post(API_URL, headers=headers, data=params, files=files, timeout=30)
        response.raise_for_status()
        result = response.json()

        formatted_result = format_results(result)
        images = result.get("images") if isinstance(result, dict) else None
        detections = (images[0].get("results") or []) if images else []
        image_with_boxes = draw_boxes_on_image(image, detections)
        return image_with_boxes, formatted_result
    except requests.exceptions.Timeout:
        return None, "#### Error: Request timeout. Please try again."
    except requests.exceptions.ConnectionError:
        return None, "#### Error: Unable to connect to detection service. Please check API configuration."
    except requests.exceptions.HTTPError as e:
        return None, f"#### Error: API returned status {e.response.status_code}"
    except Exception as e:
        # UI boundary: report anything unexpected (bad JSON, malformed
        # payload, drawing failure) as a message instead of crashing.
        return None, f"#### Error: {str(e)}"
def format_results(result):
    """Render a detection-service response as Markdown.

    Args:
        result: Decoded JSON response. When it is a dict, only the first
            entry of ``result["images"]`` is reported; its ``"shape"``,
            ``"results"`` and ``"speed"`` keys are read when present.

    Returns:
        A Markdown string. Anything that is not a dict is returned via
        ``str(result)``.
    """
    if not isinstance(result, dict):
        return str(result)

    parts = ["## Detection Results\n\n"]
    images = result.get("images") or []
    if images:
        img_data = images[0]
        shape = img_data.get("shape") or []
        detections = img_data.get("results") or []
        speed = img_data.get("speed") or {}

        # The size line is skipped when the service omits the shape, instead
        # of failing on shape[0].
        # NOTE(review): the "(W x H)" label assumes the service reports
        # [width, height]; confirm against the API, some report [h, w].
        if len(shape) >= 2:
            # Blank line after it: a single newline would merge this with the
            # next line into one Markdown paragraph.
            parts.append(f"**Image Size:** {shape[0]} x {shape[1]} (W x H)\n\n")
        parts.append(f"**Detections Found:** {len(detections)}\n\n")

        if speed:
            parts.append("\n### Performance Metrics\n")
            parts.append("| Metric | Time (ms) |\n")
            parts.append("|--------|----------|\n")
            parts.append(f"| Preprocess | {speed.get('preprocess', 'N/A')} |\n")
            parts.append(f"| Inference | {speed.get('inference', 'N/A')} |\n")
            parts.append(f"| Postprocess | {speed.get('postprocess', 'N/A')} |\n")

        if detections:
            # Leading newline separates this heading from the table above.
            parts.append("\n### Detected Objects\n")
            parts.append("| Label | Class | Confidence |\n")
            parts.append("|-------|-------|------------|\n")
            for det in detections:
                name = det.get("name", "Unknown")
                class_id = det.get("class", "N/A")
                conf = det.get("confidence") or 0
                parts.append(f"| {name} | {class_id} | {conf:.2%} |\n")
    return "".join(parts)
# Start from Gradio's Monochrome theme in slate hues, then override the text
# and background colours with dark values.
dark_theme = gr.themes.Monochrome(primary_hue="slate", secondary_hue="slate")
dark_theme = dark_theme.set(
    body_text_color="#e0e0e0",
    background_fill_primary="#0f0f0f",
    background_fill_secondary="#1a1a1a",
)
# Fullscreen viewer logic for the detection output image.
# gr.HTML renders static markup only, so a <script> placed inside it never
# runs; the script is therefore handed to Blocks through its `js` hook, which
# expects a single function that is executed when the page loads.
# Elements are looked up lazily (and retried on a timer) because the modal
# markup and the output image may be mounted after this function first runs.
_MODAL_JS = """() => {
    let scale = 1;
    let lastDistance = 0;
    let modalReady = false;

    const getModal = () => document.getElementById('imageModal');
    const getModalImg = () => document.getElementById('modalImage');

    const closeModal = () => {
        const modal = getModal();
        if (modal) {
            modal.style.display = 'none';
        }
        document.body.classList.remove('modal-open');
        scale = 1;
    };

    const openModal = (src) => {
        const modal = getModal();
        const modalImg = getModalImg();
        if (!modal || !modalImg) {
            return;
        }
        modalImg.src = src;
        modal.style.display = 'block';
        document.body.classList.add('modal-open');
        scale = 1;
        modalImg.style.transform = 'translate(-50%, -50%) scale(1)';
    };

    const pinchDistance = (touches) => {
        const dx = touches[0].clientX - touches[1].clientX;
        const dy = touches[0].clientY - touches[1].clientY;
        return Math.sqrt(dx * dx + dy * dy);
    };

    const setupModal = () => {
        const modal = getModal();
        const modalImg = getModalImg();
        const closeBtn = document.querySelector('#imageModal .closeBtn');
        if (!modal || !modalImg || !closeBtn) {
            return false;
        }
        modal.addEventListener('click', (e) => {
            if (e.target === modal) {
                closeModal();
            }
        });
        closeBtn.addEventListener('click', closeModal);
        modalImg.addEventListener('touchstart', (e) => {
            if (e.touches.length === 2) {
                lastDistance = pinchDistance(e.touches);
            }
        });
        modalImg.addEventListener('touchmove', (e) => {
            if (e.touches.length === 2 && lastDistance > 0) {
                const distance = pinchDistance(e.touches);
                scale = Math.max(1, Math.min(scale * (distance / lastDistance), 4));
                modalImg.style.transform = `translate(-50%, -50%) scale(${scale})`;
                lastDistance = distance;
            }
        });
        modalImg.addEventListener('touchend', () => {
            lastDistance = 0;
        });
        return true;
    };

    document.addEventListener('keydown', (e) => {
        const modal = getModal();
        if (e.key === 'Escape' && modal && modal.style.display === 'block') {
            closeModal();
        }
    });

    const attachImageClicks = () => {
        if (!modalReady) {
            modalReady = setupModal();
        }
        if (!modalReady) {
            return;
        }
        const container = document.getElementById('detection-output');
        if (!container) {
            return;
        }
        container.querySelectorAll('img').forEach((img) => {
            if (img.src && !img.hasClickListener) {
                img.style.cursor = 'pointer';
                img.addEventListener('click', (e) => {
                    if (e.target.src && !e.target.src.includes('data:image/svg')) {
                        openModal(e.target.src);
                    }
                });
                img.hasClickListener = true;
            }
        });
    };

    setInterval(attachImageClicks, 500);
    attachImageClicks();
}
"""

# NOTE(review): the layout nesting below was reconstructed from a source whose
# indentation had been lost - confirm that the Row is meant to sit next to
# (not inside) the heading Column.
with gr.Blocks(
    title="Figure Detection",
    theme=dark_theme,
    js=_MODAL_JS,
    css="""
footer {display: none !important;}
.gradio-container {border-radius: 12px;}
.gr-card {border-radius: 12px;}
.block {border-radius: 12px;}
.form {border-radius: 12px;}
button {border-radius: 12px;}
.gr-button {border-radius: 12px;}
#imageModal {
display: none;
position: fixed;
z-index: 10000;
left: 0;
top: 0;
width: 100%;
height: 100%;
background-color: rgba(0, 0, 0, 0.9);
animation: fadeIn 0.3s;
}
@keyframes fadeIn {
from {opacity: 0;}
to {opacity: 1;}
}
#modalImage {
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
max-width: 95%;
max-height: 95%;
object-fit: contain;
touch-action: pinch-zoom;
cursor: zoom-out;
}
.modal-open {
overflow: hidden;
}
.closeBtn {
position: absolute;
top: 20px;
right: 30px;
font-size: 40px;
font-weight: bold;
color: white;
cursor: pointer;
z-index: 10001;
}
.closeBtn:hover {
color: #bbb;
}
"""
) as demo:
    # Page heading.
    with gr.Column():
        gr.Markdown("""
# Figure Detection
Detect figures in your documents. Upload an image and adjust parameters to detect figures with custom inference settings.
""")

    with gr.Row():
        # Left column: image source and inference parameters.
        with gr.Column(scale=1, min_width=400):
            gr.Markdown("### Input")
            image_input = gr.Image(
                label="Image",
                type="pil",
                sources=["upload"],
                interactive=True
            )

            # Offer bundled sample images when the folder provides any.
            test_images = get_test_images()
            if test_images:
                # Map the displayed file name back to its path on disk.
                sample_paths = {name: path for path, name in test_images}
                test_image_radio = gr.Radio(
                    choices=list(sample_paths),
                    label="Select test image",
                    info="Click to load"
                )
                test_image_radio.change(
                    fn=lambda name: load_test_image(sample_paths.get(name)),
                    inputs=[test_image_radio],
                    outputs=[image_input]
                )
            else:
                gr.Markdown("No test images found. Add images to the 'images' folder.")

            gr.Markdown("### Configuration")
            confidence_slider = gr.Slider(
                label="Confidence Threshold",
                minimum=0.0,
                maximum=1.0,
                value=0.25,
                step=0.01,
                info="Detection confidence level"
            )
            iou_slider = gr.Slider(
                label="IOU Threshold",
                minimum=0.0,
                maximum=1.0,
                value=0.7,
                step=0.01,
                info="Intersection over union threshold"
            )
            imgsz_slider = gr.Slider(
                label="Image Size",
                minimum=320,
                maximum=1280,
                value=640,
                step=32,
                info="Inference image resolution"
            )
            predict_btn = gr.Button(
                "Detect Objects",
                variant="primary",
                size="lg",
                scale=1
            )

        # Right column: annotated image and Markdown report.
        with gr.Column(scale=1, min_width=400):
            gr.Markdown("### Results")
            image_output = gr.Image(
                label="Detections (Click to fullscreen)",
                type="pil",
                interactive=False,
                scale=1,
                # Stable hook so the viewer script binds to this image only,
                # not to the first image component on the page (the input).
                elem_id="detection-output"
            )
            results_output = gr.Markdown(
                value="Detection results will appear here.",
                label="Detection Results"
            )

    # Static markup for the fullscreen viewer; its behaviour lives in
    # _MODAL_JS above.
    gr.HTML("""
<div id="imageModal">
<span class="closeBtn">&times;</span>
<img id="modalImage" src="" alt="Fullscreen Detection">
</div>
""")

    # Re-run detection on an explicit click and whenever the image or any
    # parameter changes; every trigger uses the same handler and wiring.
    detection_inputs = [image_input, confidence_slider, iou_slider, imgsz_slider]
    detection_outputs = [image_output, results_output]
    for register_trigger in (
        predict_btn.click,
        image_input.change,
        confidence_slider.change,
        iou_slider.change,
        imgsz_slider.change,
    ):
        register_trigger(
            fn=predict_image,
            inputs=detection_inputs,
            outputs=detection_outputs
        )
if __name__ == "__main__":
    # Start the app without a public share link and with handler errors
    # shown in the UI.
    launch_options = {"share": False, "show_error": True}
    demo.launch(**launch_options)