WolseyTheCat's picture
initial commit
3d82aef
import gradio as gr
import cv2
import numpy as np
import pytesseract
import base64, json, io
from PIL import Image
# HTML template that loads Fabric.js and creates an interactive canvas.
# The literal token {data_json} is substituted (via str.replace, not
# str.format, so the JS braces below are safe) with a JSON object of the
# shape {"image_data": <base64 PNG>, "objects": [<text/image dicts>]}.
# Text objects become editable IText elements; non-text objects become
# translucent draggable rectangles over the background image.
html_template = """
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<script src="https://cdnjs.cloudflare.com/ajax/libs/fabric.js/4.6.0/fabric.min.js"></script>
<style>
canvas { border: 1px solid #ccc; }
</style>
</head>
<body>
<canvas id="c" width="600" height="400"></canvas>
<script>
// Parse JSON data from Python.
var data = {data_json};
// Initialize Fabric.js canvas.
var canvas = new fabric.Canvas('c');
// Load the image as canvas background.
var imgObj = new Image();
imgObj.src = "data:image/png;base64," + data.image_data;
imgObj.onload = function() {
var bg = new fabric.Image(imgObj);
bg.selectable = false;
// Scale background to canvas dimensions.
bg.scaleToWidth(canvas.width);
bg.scaleToHeight(canvas.height);
canvas.setBackgroundImage(bg, canvas.renderAll.bind(canvas));
};
// Add detected objects to the canvas.
data.objects.forEach(function(obj) {
if(obj.type === "text") {
var textObj = new fabric.IText(obj.text, {
left: obj.x,
top: obj.y,
fontSize: 20,
fill: 'black'
});
canvas.add(textObj);
} else if(obj.type === "image") {
var rect = new fabric.Rect({
left: obj.x,
top: obj.y,
width: obj.width,
height: obj.height,
fill: 'rgba(0, 0, 255, 0.3)'
});
canvas.add(rect);
}
});
</script>
</body>
</html>
"""
def generate_html(image):
    """Build a self-contained HTML page with an interactive Fabric.js editor.

    The uploaded image is embedded as a base64 PNG background; OCR'd words
    become editable text objects and remaining non-text contours become
    draggable rectangles.

    Parameters
    ----------
    image : numpy.ndarray
        ``H x W x 3`` (RGB) or ``H x W x 4`` (RGBA) uint8 array, as supplied
        by a ``gr.Image(type="numpy")`` component.

    Returns
    -------
    str
        A complete HTML document to render in a ``gr.HTML`` component.
    """
    image = _flatten_alpha(image)
    base64_image = _encode_png(image)
    detected_texts = _detect_text_regions(image)
    detected_images = _detect_nontext_regions(image, detected_texts)
    result = {
        "image_data": base64_image,
        "objects": detected_texts + detected_images,
    }
    # Escape "</" so a literal "</script>" inside OCR'd text cannot terminate
    # the inline <script> block early (classic JSON-in-HTML pitfall).
    json_data = json.dumps(result).replace("</", "<\\/")
    # Insert the JSON data into the HTML template (str.replace, not
    # str.format, so the template's JS braces stay intact).
    return html_template.replace("{data_json}", json_data)


def _flatten_alpha(image):
    """Composite an RGBA image onto a white background; pass RGB through.

    Guards on ``ndim`` so a 2-D grayscale array does not crash on
    ``image.shape[2]``.
    """
    if image.ndim == 3 and image.shape[2] == 4:
        alpha = image[:, :, 3:4] / 255.0  # keep trailing axis for broadcasting
        rgb = image[:, :, :3]
        white_bg = np.full_like(rgb, 255)
        image = np.uint8(rgb * alpha + white_bg * (1 - alpha))
    return image


def _encode_png(image):
    """Return the numpy image as a base64-encoded PNG string."""
    buffer = io.BytesIO()
    Image.fromarray(image).save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


def _detect_text_regions(image):
    """Run Tesseract OCR and return confident word boxes as object dicts."""
    text_data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
    detected_texts = []
    for i in range(len(text_data['level'])):
        try:
            # Tesseract reports confidence as a string; newer versions emit
            # floats like "96.33", and -1 marks non-word structural rows.
            conf = int(float(text_data['conf'][i]))
        except (ValueError, TypeError):
            conf = 0
        text_content = text_data['text'][i].strip()
        # Keep only confident, non-empty words.
        if conf > 60 and text_content:
            detected_texts.append({
                'type': 'text',
                'text': text_content,
                'x': int(text_data['left'][i]),
                'y': int(text_data['top'][i]),
                'width': int(text_data['width'][i]),
                'height': int(text_data['height'][i]),
                'confidence': conf,
            })
    return detected_texts


def _iou(box1, box2):
    """Intersection-over-union of two ``(x, y, w, h)`` boxes, in [0, 1]."""
    x1, y1, w1, h1 = box1
    x2, y2, w2, h2 = box2
    inter_w = max(0, min(x1 + w1, x2 + w2) - max(x1, x2))
    inter_h = max(0, min(y1 + h1, y2 + h2) - max(y1, y2))
    inter_area = inter_w * inter_h
    union = w1 * h1 + w2 * h2 - inter_area
    return inter_area / union if union != 0 else 0


def _detect_nontext_regions(image, detected_texts):
    """Find non-white, non-text regions and return them as object dicts.

    Thresholds near-white pixels away, takes external contours, and drops
    boxes that are tiny or substantially overlap a detected text box.
    """
    # Gradio delivers RGB arrays, so RGB2GRAY is the correct conversion
    # (BGR2GRAY would swap the red/blue luminance weights).
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    _, thresh = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY_INV)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    text_boxes = [(obj['x'], obj['y'], obj['width'], obj['height']) for obj in detected_texts]
    detected_images = []
    image_id = 0
    for cnt in contours:
        x, y, w, h = cv2.boundingRect(cnt)
        if w < 10 or h < 10:
            continue  # ignore speckle-sized contours
        # Skip contours that significantly overlap a detected text box.
        if any(_iou((x, y, w, h), tb) > 0.5 for tb in text_boxes):
            continue
        detected_images.append({
            'type': 'image',
            'id': image_id,
            'x': x,
            'y': y,
            'width': w,
            'height': h,
        })
        image_id += 1
    return detected_images
# Create the Gradio interface: upload on the left, interactive editor on the
# right; clicking the button runs generate_html and injects its HTML output.
with gr.Blocks() as demo:
    gr.Markdown("## Interactive Image Editor")
    with gr.Row():
        with gr.Column():
            # NOTE(review): `source="upload"` is Gradio 3.x API; it was
            # removed in Gradio 4 (use `sources=["upload"]`) — confirm the
            # pinned gradio version before upgrading.
            input_image = gr.Image(label="Upload PNG Image", source="upload", type="numpy")
            process_button = gr.Button("Process Image")
        with gr.Column():
            html_output = gr.HTML(label="Interactive Editor")
    # Wire the button: numpy image in, full HTML document out.
    process_button.click(fn=generate_html, inputs=input_image, outputs=html_output)
demo.launch()