Spaces:
Sleeping
Sleeping
Commit Β·
f7a075a
1
Parent(s): a9565d5
Added app.py, detection module, and model
Browse files- app.py +121 -0
- visualization.py +1395 -0
app.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import uuid
|
| 3 |
+
import shutil
|
| 4 |
+
import traceback
|
| 5 |
+
|
| 6 |
+
import gradio as gr
|
| 7 |
+
from fastapi import FastAPI, File, UploadFile
|
| 8 |
+
from fastapi.responses import JSONResponse
|
| 9 |
+
|
| 10 |
+
from visualization import process_wireframe
|
| 11 |
+
|
| 12 |
+
# -----------------------------------------------------------------------------
# FASTAPI (for Firebase / programmatic access)
# -----------------------------------------------------------------------------
api = FastAPI()

# Working directories: uploads are staged in TEMP_DIR, pipeline artifacts
# (JSON/HTML) are written under OUTPUT_DIR.
TEMP_DIR = "./temp"
OUTPUT_DIR = "./output"

# Create both directories up front so request handlers never race on them.
os.makedirs(TEMP_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@api.get("/")
def health_check():
    """Liveness probe: report that the service is up and reachable."""
    payload = {"status": "ok"}
    return payload
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@api.post("/process-wireframe")
async def process_wireframe_api(image: UploadFile = File(...)):
    """Run the wireframe-detection pipeline on an uploaded image.

    The upload is staged to a unique temp file, processed with
    ``process_wireframe``, and the temp file is always removed afterwards.

    Returns:
        200 with paths to the generated JSON/HTML artifacts and the element
        count, 400 if nothing was detected, 500 on any pipeline error.
    """
    # Unique prefix prevents collisions between concurrent uploads that
    # happen to share a filename.
    file_id = str(uuid.uuid4())
    temp_path = os.path.join(TEMP_DIR, f"{file_id}_{image.filename}")

    try:
        with open(temp_path, "wb") as f:
            shutil.copyfileobj(image.file, f)

        results = process_wireframe(
            image_path=temp_path,
            save_json=True,
            save_html=True,
            show_visualization=False
        )

        if not results:
            return JSONResponse(
                status_code=400,
                content={"error": "No elements detected"}
            )

        return {
            "success": True,
            "json_path": results.get("json_path"),
            "html_path": results.get("html_path"),
            # FIX: use .get() with a default like the sibling keys above —
            # a missing "normalized_elements" key previously raised KeyError
            # and turned a partial result into a 500.
            "total_elements": len(results.get("normalized_elements", []))
        }

    except Exception as e:
        # Log the full traceback server-side but return only the message.
        traceback.print_exc()
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )

    finally:
        # Always reclaim the staged upload, on success and failure alike.
        if os.path.exists(temp_path):
            os.remove(temp_path)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# -----------------------------------------------------------------------------
# GRADIO (for Hugging Face UI)
# -----------------------------------------------------------------------------
def gradio_process(image):
    """
    Gradio passes a PIL Image.
    We save it temporarily and reuse the SAME pipeline.

    Returns a (status message, json_path-or-None) pair for the two UI outputs.
    """
    temp_path = f"{TEMP_DIR}/{uuid.uuid4()}.png"

    try:
        # FIX: save inside the try block. Previously a failed save raised
        # before the try, bypassing the friendly "Error: ..." reporting
        # below (the finally-cleanup tolerates a missing file either way).
        image.save(temp_path)

        results = process_wireframe(
            image_path=temp_path,
            save_json=True,
            save_html=True,
            show_visualization=False
        )

        if not results:
            return "No elements detected", None

        return (
            # .get() with a default so a missing key reads as 0 elements
            # instead of crashing the UI callback.
            f"Detected {len(results.get('normalized_elements', []))} elements",
            results.get("json_path")
        )

    except Exception as e:
        traceback.print_exc()
        return f"Error: {str(e)}", None

    finally:
        # Always remove the staged temp image.
        if os.path.exists(temp_path):
            os.remove(temp_path)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# Gradio UI wrapper around gradio_process: one image input, two outputs
# (status text + downloadable normalized-JSON file).
demo = gr.Interface(
    fn=gradio_process,
    inputs=gr.Image(type="pil", label="Upload Wireframe"),
    outputs=[
        gr.Textbox(label="Status"),
        gr.File(label="Normalized JSON Output")
    ],
    title="Wireframe Layout Normalizer",
    description="Upload a wireframe image to extract and normalize UI layout"
)
|
| 116 |
+
|
| 117 |
+
# -----------------------------------------------------------------------------
# ENTRY POINT (THIS IS IMPORTANT)
# -----------------------------------------------------------------------------
# Mount the Gradio UI onto the FastAPI app at "/". `app` is the ASGI object
# the server (e.g. uvicorn) serves; the API routes above remain reachable.
app = gr.mount_gradio_app(api, demo, path="/")
|
| 121 |
+
|
visualization.py
ADDED
|
@@ -0,0 +1,1395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import matplotlib.patches as patches
|
| 5 |
+
from PIL import Image, ImageOps
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from typing import List, Tuple, Dict, Optional
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# ============================================================================
# CUSTOM LOSS CLASS (Required for model loading)
# ============================================================================
@tf.keras.utils.register_keras_serializable()
class LossCalculation(tf.keras.losses.Loss):
    """Custom loss function for wireframe detection.

    Composite loss over a grid-cell prediction tensor laid out as
    [objectness, x, y, w, h, class logits...] along the last axis:
    objectness BCE + coordinate smooth-L1 + per-cell class cross-entropy.
    Registered as Keras-serializable so saved models can be reloaded.
    """

    def __init__(self, num_classes=7, lambda_coord=5.0, lambda_noobj=0.5,
                 name='loss_calculation', reduction='sum_over_batch_size', **kwargs):
        """Store loss weights.

        Args:
            num_classes: number of element classes (kept for serialization;
                the class slice itself is taken from the tensor layout).
            lambda_coord: weight on the box-coordinate loss term.
            lambda_noobj: down-weight for objectness loss on empty cells.
            name, reduction: forwarded to the Keras Loss base class.
            **kwargs: accepted for config round-trips; currently unused.
        """
        super().__init__(name=name, reduction=reduction)
        self.num_classes = num_classes
        self.lambda_coord = lambda_coord
        self.lambda_noobj = lambda_noobj

    def call(self, y_true, y_pred):
        """Compute the combined objectness + box + class loss."""
        # Split ground truth: channel 0 objectness, 1-4 box, rest one-hot class.
        obj_true = y_true[..., 0]
        box_true = y_true[..., 1:5]
        cls_true = y_true[..., 5:]

        # Predictions are raw logits; activations are applied below.
        obj_pred_logits = y_pred[..., 0]
        box_pred = y_pred[..., 1:5]
        cls_pred_logits = y_pred[..., 5:]

        obj_mask = tf.cast(obj_true > 0.5, tf.float32)
        noobj_mask = 1.0 - obj_mask
        # Clamp positive count to >= 1 so the per-positive averages below
        # never divide by zero on an empty target.
        num_pos = tf.maximum(tf.reduce_sum(obj_mask), 1.0)

        # Objectness BCE, with negative cells down-weighted by lambda_noobj
        # and the sum averaged over ALL cells.
        obj_loss_pos = obj_mask * tf.nn.sigmoid_cross_entropy_with_logits(
            labels=obj_true, logits=obj_pred_logits)
        obj_loss_neg = noobj_mask * tf.nn.sigmoid_cross_entropy_with_logits(
            labels=obj_true, logits=obj_pred_logits)
        obj_loss = (tf.reduce_sum(obj_loss_pos) + self.lambda_noobj * tf.reduce_sum(obj_loss_neg)) / tf.cast(
            tf.size(obj_true), tf.float32)

        # Box coordinates are sigmoid-squashed before comparison; targets are
        # assumed already in that squashed range.
        xy_pred = tf.nn.sigmoid(box_pred[..., 0:2])
        wh_pred = tf.nn.sigmoid(box_pred[..., 2:4])
        xy_true = box_true[..., 0:2]
        wh_true = box_true[..., 2:4]

        # Smooth-L1 on positives only, averaged per positive cell.
        xy_loss = tf.reduce_sum(obj_mask[..., tf.newaxis] * self._smooth_l1_loss(xy_true - xy_pred)) / num_pos
        wh_loss = tf.reduce_sum(obj_mask[..., tf.newaxis] * self._smooth_l1_loss(wh_true - wh_pred)) / num_pos
        box_loss = self.lambda_coord * (xy_loss + wh_loss)

        # Class cross-entropy only where an object exists.
        cls_loss = tf.reduce_sum(obj_mask * tf.nn.softmax_cross_entropy_with_logits(
            labels=cls_true, logits=cls_pred_logits)) / num_pos

        total_loss = obj_loss + box_loss + cls_loss
        # Clip to guard training against exploding loss spikes.
        return tf.clip_by_value(total_loss, 0.0, 100.0)

    def _smooth_l1_loss(self, x, beta=1.0):
        """Elementwise smooth-L1 (Huber-style): quadratic below beta, linear above."""
        abs_x = tf.abs(x)
        return tf.where(abs_x < beta, 0.5 * x * x / beta, abs_x - 0.5 * beta)

    def get_config(self):
        """Serialize the loss hyperparameters for model saving."""
        config = super().get_config()
        config.update({
            'num_classes': self.num_classes,
            'lambda_coord': self.lambda_coord,
            'lambda_noobj': self.lambda_noobj,
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the loss from a serialized config dict."""
        return cls(**config)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# ============================================================================
# CONFIGURATION - UPDATED FOR BETTER PRECISION
# ============================================================================
MODEL_PATH = "./wireframe_detection_model_best_700.keras"
OUTPUT_DIR = "./output/"
# Class order must match the model's output channel order.
CLASS_NAMES = ["button", "checkbox", "image", "navbar", "paragraph", "text", "textfield"]

# Model input resolution, detection confidence floor, and NMS IoU threshold.
IMG_SIZE = 416
CONF_THRESHOLD = 0.1
IOU_THRESHOLD = 0.1

# Layout Configuration - INCREASED GRID DENSITY
GRID_COLUMNS = 24  # Doubled from 12 for finer precision
ALIGNMENT_THRESHOLD = 10  # Reduced from 15 for tighter alignment (pixels)
SIZE_CLUSTERING_THRESHOLD = 15  # Reduced from 20 for better size grouping (pixels)

# Standard sizes for each element type (relative units) - UPDATED FOR SMALLER BUTTONS/CHECKBOXES
STANDARD_SIZES = {
    'button': {'width': 2, 'height': 1},  # Smaller button (was 2x1, now in finer grid)
    'checkbox': {'width': 1, 'height': 1},  # Keep small checkbox
    'textfield': {'width': 5, 'height': 1},  # Adjusted for new grid
    'text': {'width': 3, 'height': 1},  # Adjusted
    'paragraph': {'width': 8, 'height': 2},  # Adjusted
    'image': {'width': 4, 'height': 4},  # Adjusted
    'navbar': {'width': 24, 'height': 1}  # Full width in new grid
}

# Global model handle, populated elsewhere; get_predictions() raises
# ValueError while it is still None.
model = None
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# ============================================================================
# DATA STRUCTURES
# ============================================================================
@dataclass
class Element:
    """A single detected UI element.

    Derived geometry (width, height, center point) is filled in
    automatically from the bounding box after construction.
    """
    label: str
    score: float
    bbox: List[float]  # [x1, y1, x2, y2]
    width: float = 0
    height: float = 0
    center_x: float = 0
    center_y: float = 0

    def __post_init__(self):
        # Unpack the corners once, then derive size and center from them.
        x1, y1, x2, y2 = self.bbox
        self.width = x2 - x1
        self.height = y2 - y1
        self.center_x = (x1 + x2) / 2
        self.center_y = (y1 + y2) / 2
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@dataclass
class NormalizedElement:
    """Represents a normalized UI element."""
    original: Element  # the raw detection this normalized entry was derived from
    normalized_bbox: List[float]  # adjusted [x1, y1, x2, y2] box
    grid_position: Dict  # grid placement metadata (structure defined by the layout system)
    size_category: str  # label of the size bucket assigned to this element
    alignment_group: Optional[int] = None  # alignment-group id, if one was assigned
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# ============================================================================
# PREDICTION EXTRACTION
# ============================================================================
def get_predictions(image_path: str) -> Tuple[Image.Image, List[Element]]:
    """Extract predictions from the model.

    Loads the image, runs the globally-loaded detector, decodes the S x S
    grid output into boxes in original-image pixel coordinates, and applies
    per-class non-max suppression.

    Args:
        image_path: path to the image file to analyze.

    Returns:
        (pil_img, elements): the EXIF-corrected RGB image and the list of
        kept Element detections.

    Raises:
        ValueError: if the global ``model`` has not been loaded.
    """
    global model
    if model is None:
        raise ValueError("Model not loaded. Please load the model first.")

    # Load and preprocess image
    pil_img = Image.open(image_path).convert("RGB")
    pil_img = ImageOps.exif_transpose(pil_img)  # honor camera orientation metadata
    orig_w, orig_h = pil_img.size
    resized_img = pil_img.resize((IMG_SIZE, IMG_SIZE), Image.LANCZOS)
    img_array = np.array(resized_img, dtype=np.float32) / 255.0  # scale to [0, 1]
    input_tensor = np.expand_dims(img_array, axis=0)  # add batch dimension

    # Get predictions
    pred_grid = model.predict(input_tensor, verbose=0)[0]
    raw_boxes = []
    S = pred_grid.shape[0]  # prediction grid is S x S cells
    cell_size = 1.0 / S  # cell extent in normalized image coordinates

    for row in range(S):
        for col in range(S):
            # Objectness first: cheap early reject before decoding the box.
            obj_score = float(tf.nn.sigmoid(pred_grid[row, col, 0]))
            if obj_score < CONF_THRESHOLD:
                continue

            # Box channels are sigmoid-squashed: (x, y) are offsets within
            # the cell, (w, h) are fractions of the whole image.
            x_offset = float(tf.nn.sigmoid(pred_grid[row, col, 1]))
            y_offset = float(tf.nn.sigmoid(pred_grid[row, col, 2]))
            width = float(tf.nn.sigmoid(pred_grid[row, col, 3]))
            height = float(tf.nn.sigmoid(pred_grid[row, col, 4]))

            class_logits = pred_grid[row, col, 5:]
            class_probs = tf.nn.softmax(class_logits).numpy()
            class_id = int(np.argmax(class_probs))
            class_conf = float(class_probs[class_id])
            # Combined confidence = objectness x best-class probability.
            final_score = obj_score * class_conf

            if final_score < CONF_THRESHOLD:
                continue

            # Convert normalized center/size to original-image pixel corners.
            center_x = (col + x_offset) * cell_size
            center_y = (row + y_offset) * cell_size
            x1 = (center_x - width / 2) * orig_w
            y1 = (center_y - height / 2) * orig_h
            x2 = (center_x + width / 2) * orig_w
            y2 = (center_y + height / 2) * orig_h

            # Drop degenerate (zero/negative area) boxes.
            if x2 > x1 and y2 > y1:
                raw_boxes.append((class_id, final_score, x1, y1, x2, y2))

    # Apply NMS per class
    elements = []
    for class_id in range(len(CLASS_NAMES)):
        class_boxes = [(score, x1, y1, x2, y2) for cid, score, x1, y1, x2, y2 in raw_boxes if cid == class_id]
        if not class_boxes:
            continue

        scores = [b[0] for b in class_boxes]
        boxes_xyxy = [[b[1], b[2], b[3], b[4]] for b in class_boxes]

        selected_indices = tf.image.non_max_suppression(
            boxes=boxes_xyxy,
            scores=scores,
            max_output_size=50,
            iou_threshold=IOU_THRESHOLD,
            score_threshold=CONF_THRESHOLD
        )

        # Keep only the NMS survivors, wrapped as Element records.
        for idx in selected_indices.numpy():
            score, x1, y1, x2, y2 = class_boxes[idx]
            elements.append(Element(
                label=CLASS_NAMES[class_id],
                score=float(score),
                bbox=[float(x1), float(y1), float(x2), float(y2)]
            ))

    return pil_img, elements
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
# ============================================================================
# ALIGNMENT DETECTION
# ============================================================================
class AlignmentDetector:
    """Detects alignment relationships between elements.

    Groups elements whose centers or edges fall within ``threshold`` pixels
    of each other along an axis. Only groups of two or more elements are
    reported.
    """

    def __init__(self, elements: List[Element], threshold: float = ALIGNMENT_THRESHOLD):
        """Store the elements to analyze and the alignment tolerance (pixels)."""
        self.elements = elements
        self.threshold = threshold

    def detect_horizontal_alignments(self) -> List[List[Element]]:
        """Group elements that are horizontally aligned (same Y position)."""
        # Cluster on center_y; order each group left-to-right by center_x.
        return self._cluster_by_center(lambda e: e.center_y, lambda e: e.center_x)

    def detect_vertical_alignments(self) -> List[List[Element]]:
        """Group elements that are vertically aligned (same X position)."""
        # Cluster on center_x; order each group top-to-bottom by center_y.
        return self._cluster_by_center(lambda e: e.center_x, lambda e: e.center_y)

    def _cluster_by_center(self, cluster_key, sort_key) -> List[List[Element]]:
        """Shared axis-agnostic clustering for the two detect_*_alignments methods.

        Sorts elements by ``cluster_key`` and sweeps once, starting a new
        group whenever an element's key differs from the current group's
        average by more than the threshold. Groups of size > 1 are emitted,
        each sorted by ``sort_key``.
        """
        if not self.elements:
            return []

        sorted_elements = sorted(self.elements, key=cluster_key)
        groups = []
        current_group = [sorted_elements[0]]

        for elem in sorted_elements[1:]:
            # Compare against the running group average, not the last member,
            # so slowly drifting chains do not merge indefinitely.
            avg = sum(cluster_key(e) for e in current_group) / len(current_group)
            if abs(cluster_key(elem) - avg) <= self.threshold:
                current_group.append(elem)
            else:
                if len(current_group) > 1:
                    current_group.sort(key=sort_key)
                    groups.append(current_group)
                current_group = [elem]

        # Flush the final group.
        if len(current_group) > 1:
            current_group.sort(key=sort_key)
            groups.append(current_group)

        return groups

    def detect_edge_alignments(self) -> Dict[str, List[List[Element]]]:
        """Detect elements with aligned edges (left, right, top, bottom)."""
        alignments = {
            'left': [],
            'right': [],
            'top': [],
            'bottom': []
        }

        if not self.elements:
            return alignments

        # bbox = [x1, y1, x2, y2]: cluster each edge coordinate separately.
        sorted_left = sorted(self.elements, key=lambda e: e.bbox[0])
        alignments['left'] = self._cluster_by_value(sorted_left, lambda e: e.bbox[0])

        sorted_right = sorted(self.elements, key=lambda e: e.bbox[2])
        alignments['right'] = self._cluster_by_value(sorted_right, lambda e: e.bbox[2])

        sorted_top = sorted(self.elements, key=lambda e: e.bbox[1])
        alignments['top'] = self._cluster_by_value(sorted_top, lambda e: e.bbox[1])

        sorted_bottom = sorted(self.elements, key=lambda e: e.bbox[3])
        alignments['bottom'] = self._cluster_by_value(sorted_bottom, lambda e: e.bbox[3])

        return alignments

    def _cluster_by_value(self, elements: List[Element], value_func) -> List[List[Element]]:
        """Cluster pre-sorted elements by a value function within threshold.

        Unlike _cluster_by_center, the running reference value is updated
        incrementally and the emitted groups keep their input order.
        """
        if not elements:
            return []

        groups = []
        current_group = [elements[0]]
        current_value = value_func(elements[0])

        for elem in elements[1:]:
            elem_value = value_func(elem)
            if abs(elem_value - current_value) <= self.threshold:
                current_group.append(elem)
                # Fold the new value into the running average reference.
                current_value = (current_value * (len(current_group) - 1) + elem_value) / len(current_group)
            else:
                if len(current_group) > 1:
                    groups.append(current_group)
                current_group = [elem]
                current_value = elem_value

        if len(current_group) > 1:
            groups.append(current_group)

        return groups
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
# ============================================================================
# SIZE NORMALIZATION - UPDATED TO RESPECT ACTUAL SIZES MORE
# ============================================================================
class SizeNormalizer:
    """Normalizes element sizes based on type and clustering."""

    def __init__(self, elements: List[Element], img_width: float, img_height: float):
        """Store the detections and image dimensions used for normalization."""
        self.elements = elements
        self.img_width = img_width
        self.img_height = img_height
        self.size_clusters = {}  # reserved for callers to cache cluster results

    def cluster_sizes_by_type(self) -> Dict[str, List[List[Element]]]:
        """Cluster elements of same type by similar sizes.

        Returns a mapping from class label to clusters, where each cluster
        holds elements of similar width AND height.
        """
        clusters_by_type = {}

        for label in CLASS_NAMES:
            type_elements = [e for e in self.elements if e.label == label]
            if not type_elements:
                continue

            # Two-pass clustering: group by width first, then split each
            # width cluster by height.
            width_clusters = self._cluster_by_dimension(type_elements, 'width')
            final_clusters = []
            for width_cluster in width_clusters:
                height_clusters = self._cluster_by_dimension(width_cluster, 'height')
                final_clusters.extend(height_clusters)

            clusters_by_type[label] = final_clusters

        return clusters_by_type

    def _cluster_by_dimension(self, elements: List[Element], dimension: str) -> List[List[Element]]:
        """Cluster elements by width or height.

        Single sweep over elements sorted by the chosen dimension; a new
        cluster starts when an element deviates from the running cluster
        average by more than SIZE_CLUSTERING_THRESHOLD pixels.
        """
        if not elements:
            return []

        sorted_elements = sorted(elements, key=lambda e: getattr(e, dimension))
        clusters = []
        current_cluster = [sorted_elements[0]]

        for elem in sorted_elements[1:]:
            avg_dim = sum(getattr(e, dimension) for e in current_cluster) / len(current_cluster)
            if abs(getattr(elem, dimension) - avg_dim) <= SIZE_CLUSTERING_THRESHOLD:
                current_cluster.append(elem)
            else:
                clusters.append(current_cluster)
                current_cluster = [elem]

        clusters.append(current_cluster)
        return clusters

    def get_normalized_size(self, element: Element, size_cluster: List[Element]) -> Tuple[float, float]:
        """Get normalized size for an element based on its cluster - PRESERVES ACTUAL SIZE BETTER.

        For clusters of 3+ elements, snaps the element to the cluster's
        median width/height when it is within 30% of that median; otherwise
        the element's own (rounded) size is kept.
        """
        # Use the actual detected size instead of aggressive averaging.
        # Only normalize if there's a significant cluster.
        if len(size_cluster) >= 3:
            # Use median instead of mean to avoid outliers.
            widths = sorted([e.width for e in size_cluster])
            heights = sorted([e.height for e in size_cluster])
            median_width = widths[len(widths) // 2]
            median_height = heights[len(heights) // 2]

            # FIX: guard against a zero median (degenerate detections) which
            # previously raised ZeroDivisionError in the 30% test below.
            if median_width and abs(element.width - median_width) / median_width < 0.3:
                normalized_width = round(median_width)
            else:
                normalized_width = round(element.width)

            if median_height and abs(element.height - median_height) / median_height < 0.3:
                normalized_height = round(median_height)
            else:
                normalized_height = round(element.height)
        else:
            # Small cluster - keep original size
            normalized_width = round(element.width)
            normalized_height = round(element.height)

        return normalized_width, normalized_height
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
# ============================================================================
|
| 416 |
+
# GRID-BASED LAYOUT SYSTEM - UPDATED FOR FINER PRECISION
|
| 417 |
+
# ============================================================================
|
| 418 |
+
class GridLayoutSystem:
    """Grid-based layout system for precise positioning.

    Divides the image into ``num_columns`` equal-width columns and derives a
    row count from the column width so cells come out roughly square.
    Elements can then be snapped to cell boundaries and described in grid
    terms (start row/col plus row/col spans).
    """

    def __init__(self, img_width: float, img_height: float, num_columns: int = GRID_COLUMNS):
        self.img_width = img_width
        self.img_height = img_height
        self.num_columns = num_columns

        # Row count is derived from the column width (near-square cells);
        # at least one row so cell_height is always well-defined.
        self.cell_width = img_width / num_columns
        self.num_rows = max(1, int(img_height / self.cell_width))
        self.cell_height = img_height / self.num_rows

        print(f"π Grid System: {self.num_columns} columns Γ {self.num_rows} rows")
        print(f"π Cell size: {self.cell_width:.1f}px Γ {self.cell_height:.1f}px")

    def snap_to_grid(self, bbox: List[float], element_label: str, preserve_size: bool = True) -> List[float]:
        """Snap a bounding box onto the grid, keeping it centered on its original center.

        Args:
            bbox: [x1, y1, x2, y2] in pixels.
            element_label: Element type; used to look up STANDARD_SIZES when
                preserve_size is False.
            preserve_size: When True (default), the snapped box spans the
                number of whole cells closest to its real size. When False,
                the span may be nudged to the element type's standard span.

        Returns:
            New [x1, y1, x2, y2] aligned to cell boundaries, clamped to the grid.
        """
        x1, y1, x2, y2 = bbox
        original_width = x2 - x1
        original_height = y2 - y1

        # Snap the element's center to the nearest cell corner.
        center_x = (x1 + x2) / 2
        center_y = (y1 + y2) / 2
        center_col = round(center_x / self.cell_width)
        center_row = round(center_y / self.cell_height)

        # Span in whole cells, never less than one cell in either direction.
        # (Previously computed identically in both branches - hoisted.)
        width_cells = max(1, round(original_width / self.cell_width))
        height_cells = max(1, round(original_height / self.cell_height))

        if not preserve_size:
            # Nudge the span to the type's standard size when it matches.
            # NOTE: spans are ints, so the 0.5 tolerance only fires on exact
            # equality - kept as-is to preserve existing behavior.
            standard = STANDARD_SIZES.get(element_label, {'width': 2, 'height': 1})
            if abs(width_cells - standard['width']) <= 0.5:
                width_cells = standard['width']
            if abs(height_cells - standard['height']) <= 0.5:
                height_cells = standard['height']

        # Center the span on the snapped center, then clamp inside the grid.
        start_col = max(0, min(center_col - width_cells // 2, self.num_columns - width_cells))
        start_row = max(0, min(center_row - height_cells // 2, self.num_rows - height_cells))

        # Convert cell coordinates back to pixels.
        return [
            start_col * self.cell_width,
            start_row * self.cell_height,
            (start_col + width_cells) * self.cell_width,
            (start_row + height_cells) * self.cell_height,
        ]

    def get_grid_position(self, bbox: List[float]) -> Dict:
        """Describe which grid cells a pixel bbox occupies (rows, cols, spans)."""
        x1, y1, x2, y2 = bbox

        # Start indices floor, end indices ceil: a box partially covering a
        # cell still counts as occupying it.
        start_col = int(x1 / self.cell_width)
        start_row = int(y1 / self.cell_height)
        end_col = int(np.ceil(x2 / self.cell_width))
        end_row = int(np.ceil(y2 / self.cell_height))

        return {
            'start_row': start_row,
            'end_row': end_row,
            'start_col': start_col,
            'end_col': end_col,
            'rowspan': end_row - start_row,
            'colspan': end_col - start_col
        }
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
# ============================================================================
|
| 500 |
+
# OVERLAP DETECTION & RESOLUTION - UPDATED WITH BETTER STRATEGIES
|
| 501 |
+
# ============================================================================
|
| 502 |
+
class OverlapResolver:
    """Detects and resolves overlapping elements.

    Operates on NormalizedElement bboxes, which are mutated IN PLACE. Three
    escalating strategies are applied depending on overlap severity:
      * >70% overlap   -> remove the lower-confidence element entirely,
      * 40-70% overlap -> push the two boxes apart along their dominant axis,
      * 20-40% overlap -> shrink both boxes away from the shared region.
    """

    def __init__(self, elements: List[Element], img_width: float, img_height: float):
        # Original detections kept for reference; resolution itself works on
        # the NormalizedElement list passed to resolve_overlaps().
        self.elements = elements
        self.img_width = img_width
        self.img_height = img_height
        # Minimum fraction of either box covered by the intersection before
        # we bother resolving the pair.
        self.overlap_threshold = 0.2  # Reduced from 0.3 - be more aggressive

    def compute_iou(self, bbox1: List[float], bbox2: List[float]) -> float:
        """Compute Intersection over Union between two bounding boxes."""
        # Intersection rectangle corners.
        x1 = max(bbox1[0], bbox2[0])
        y1 = max(bbox1[1], bbox2[1])
        x2 = min(bbox1[2], bbox2[2])
        y2 = min(bbox1[3], bbox2[3])

        # No intersection at all.
        if x2 <= x1 or y2 <= y1:
            return 0.0

        intersection = (x2 - x1) * (y2 - y1)
        area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
        area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
        union = area1 + area2 - intersection

        return intersection / union if union > 0 else 0.0

    def compute_overlap_ratio(self, bbox1: List[float], bbox2: List[float]) -> Tuple[float, float]:
        """Compute what percentage of each box overlaps with the other.

        Returns (intersection/area1, intersection/area2). Unlike IoU, this
        catches the case where a small box is fully contained in a big one.
        """
        x1 = max(bbox1[0], bbox2[0])
        y1 = max(bbox1[1], bbox2[1])
        x2 = min(bbox1[2], bbox2[2])
        y2 = min(bbox1[3], bbox2[3])

        if x2 <= x1 or y2 <= y1:
            return 0.0, 0.0

        intersection = (x2 - x1) * (y2 - y1)
        area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
        area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])

        overlap_ratio1 = intersection / area1 if area1 > 0 else 0.0
        overlap_ratio2 = intersection / area2 if area2 > 0 else 0.0

        return overlap_ratio1, overlap_ratio2

    def resolve_overlaps(self, normalized_elements: List[NormalizedElement]) -> List[NormalizedElement]:
        """Resolve overlaps by adjusting element positions - IMPROVED ALGORITHM.

        Pairs are processed worst-overlap first; boxes are mutated in place,
        so later pairs see the already-adjusted positions. Returns the input
        list minus any elements removed as near-duplicates.
        """
        print("\nπ Checking for overlaps...")

        # Collect every pair whose mutual coverage exceeds the threshold.
        overlaps = []
        for i in range(len(normalized_elements)):
            for j in range(i + 1, len(normalized_elements)):
                ne1 = normalized_elements[i]
                ne2 = normalized_elements[j]

                iou = self.compute_iou(ne1.normalized_bbox, ne2.normalized_bbox)
                if iou > 0:
                    overlap1, overlap2 = self.compute_overlap_ratio(
                        ne1.normalized_bbox, ne2.normalized_bbox
                    )
                    # Use the larger coverage so containment is not missed.
                    max_overlap = max(overlap1, overlap2)

                    if max_overlap >= self.overlap_threshold:
                        overlaps.append({
                            'idx1': i,
                            'idx2': j,
                            'elem1': ne1,
                            'elem2': ne2,
                            'overlap': max_overlap,
                            'overlap1': overlap1,
                            'overlap2': overlap2,
                            'iou': iou
                        })

        if not overlaps:
            print("β No significant overlaps detected")
            return normalized_elements

        print(f"β οΈ Found {len(overlaps)} overlapping element pairs")

        # Sort by overlap severity (worst first) so the biggest conflicts are
        # resolved before milder ones get tweaked.
        overlaps.sort(key=lambda x: x['overlap'], reverse=True)

        elements_to_remove = set()

        for overlap_info in overlaps:
            idx1 = overlap_info['idx1']
            idx2 = overlap_info['idx2']

            # Skip pairs involving an element already slated for removal.
            if idx1 in elements_to_remove or idx2 in elements_to_remove:
                continue

            elem1 = overlap_info['elem1']
            elem2 = overlap_info['elem2']
            overlap_ratio = overlap_info['overlap']

            # Strategy 1: Nearly complete overlap (>70%) - remove lower confidence
            if overlap_ratio > 0.7:
                if elem1.original.score < elem2.original.score:
                    elements_to_remove.add(idx1)
                    print(f" ποΈ Removing {elem1.original.label} (conf: {elem1.original.score:.2f}) - "
                          f"overlaps {overlap_ratio * 100:.1f}% with {elem2.original.label}")
                else:
                    elements_to_remove.add(idx2)
                    print(f" ποΈ Removing {elem2.original.label} (conf: {elem2.original.score:.2f}) - "
                          f"overlaps {overlap_ratio * 100:.1f}% with {elem1.original.label}")

            # Strategy 2: Significant overlap (40-70%) - try to separate
            elif overlap_ratio > 0.4:
                self._try_separate_elements(elem1, elem2, overlap_info)
                print(f" βοΈ Separating {elem1.original.label} and {elem2.original.label} "
                      f"(overlap: {overlap_ratio * 100:.1f}%)")

            # Strategy 3: Moderate overlap (20-40%) - shrink slightly
            else:
                self._shrink_overlapping_edges(elem1, elem2, overlap_info)
                print(f" π Shrinking {elem1.original.label} and {elem2.original.label} "
                      f"(overlap: {overlap_ratio * 100:.1f}%)")

        if elements_to_remove:
            normalized_elements = [
                ne for i, ne in enumerate(normalized_elements)
                if i not in elements_to_remove
            ]
            print(f"β Removed {len(elements_to_remove)} completely overlapping elements")

        return normalized_elements

    def _try_separate_elements(self, elem1: NormalizedElement, elem2: NormalizedElement,
                               overlap_info: Dict):
        """Try to separate two significantly overlapping elements - IMPROVED.

        Pushes the facing edges of the two boxes to either side of their
        midpoint (plus a small gap), along whichever axis the centers are
        farther apart. Mutates both bbox lists in place.
        """
        bbox1 = elem1.normalized_bbox
        bbox2 = elem2.normalized_bbox

        # Calculate overlap dimensions
        # NOTE(review): overlap_width/overlap_height are computed but unused
        # below - separation is driven purely by center distance.
        overlap_x1 = max(bbox1[0], bbox2[0])
        overlap_y1 = max(bbox1[1], bbox2[1])
        overlap_x2 = min(bbox1[2], bbox2[2])
        overlap_y2 = min(bbox1[3], bbox2[3])

        overlap_width = overlap_x2 - overlap_x1
        overlap_height = overlap_y2 - overlap_y1

        # Calculate centers
        center1_x = (bbox1[0] + bbox1[2]) / 2
        center1_y = (bbox1[1] + bbox1[3]) / 2
        center2_x = (bbox2[0] + bbox2[2]) / 2
        center2_y = (bbox2[1] + bbox2[3]) / 2

        # Determine separation direction: split along the axis with the
        # larger center-to-center distance.
        dx = abs(center2_x - center1_x)
        dy = abs(center2_y - center1_y)

        # Add minimum gap
        min_gap = 3  # pixels

        if dx > dy:
            # Separate horizontally
            if center1_x < center2_x:
                # elem1 is left of elem2
                midpoint = (bbox1[2] + bbox2[0]) / 2
                bbox1[2] = midpoint - min_gap
                bbox2[0] = midpoint + min_gap
            else:
                # elem2 is left of elem1
                midpoint = (bbox2[2] + bbox1[0]) / 2
                bbox2[2] = midpoint - min_gap
                bbox1[0] = midpoint + min_gap
        else:
            # Separate vertically
            if center1_y < center2_y:
                # elem1 is above elem2
                midpoint = (bbox1[3] + bbox2[1]) / 2
                bbox1[3] = midpoint - min_gap
                bbox2[1] = midpoint + min_gap
            else:
                # elem2 is above elem1
                midpoint = (bbox2[3] + bbox1[1]) / 2
                bbox2[3] = midpoint - min_gap
                bbox1[1] = midpoint + min_gap

        # Ensure boxes remain valid (the edge moves above can invert a box
        # when one element contains the other; this restores a minimum size).
        self._ensure_valid_bbox(bbox1)
        self._ensure_valid_bbox(bbox2)

    def _shrink_overlapping_edges(self, elem1: NormalizedElement, elem2: NormalizedElement,
                                  overlap_info: Dict):
        """Shrink overlapping edges for moderate overlaps.

        Pulls each box back by half the overlap (plus a small gap) along the
        axis where the shared region is widest. Mutates bboxes in place.
        """
        bbox1 = elem1.normalized_bbox
        bbox2 = elem2.normalized_bbox

        # Calculate overlap region
        overlap_x1 = max(bbox1[0], bbox2[0])
        overlap_y1 = max(bbox1[1], bbox2[1])
        overlap_x2 = min(bbox1[2], bbox2[2])
        overlap_y2 = min(bbox1[3], bbox2[3])

        overlap_width = overlap_x2 - overlap_x1
        overlap_height = overlap_y2 - overlap_y1

        # Shrink by 50% of overlap plus small gap
        gap = 2  # pixels

        if overlap_width > overlap_height:
            # Horizontal overlap is larger
            shrink = overlap_width / 2 + gap
            if bbox1[0] < bbox2[0]:
                bbox1[2] -= shrink
                bbox2[0] += shrink
            else:
                bbox2[2] -= shrink
                bbox1[0] += shrink
        else:
            # Vertical overlap is larger
            shrink = overlap_height / 2 + gap
            if bbox1[1] < bbox2[1]:
                bbox1[3] -= shrink
                bbox2[1] += shrink
            else:
                bbox2[3] -= shrink
                bbox1[1] += shrink

        self._ensure_valid_bbox(bbox1)
        self._ensure_valid_bbox(bbox2)

    def _ensure_valid_bbox(self, bbox: List[float]):
        """Ensure bounding box has minimum size and is within image bounds.

        Mutates *bbox* in place: re-centers any side smaller than min_size,
        then clamps all coordinates into the image rectangle.
        """
        min_size = 8  # Reduced minimum size

        # Ensure minimum size
        if bbox[2] - bbox[0] < min_size:
            center_x = (bbox[0] + bbox[2]) / 2
            bbox[0] = center_x - min_size / 2
            bbox[2] = center_x + min_size / 2

        if bbox[3] - bbox[1] < min_size:
            center_y = (bbox[1] + bbox[3]) / 2
            bbox[1] = center_y - min_size / 2
            bbox[3] = center_y + min_size / 2

        # Clamp to image bounds
        bbox[0] = max(0, min(bbox[0], self.img_width))
        bbox[1] = max(0, min(bbox[1], self.img_height))
        bbox[2] = max(0, min(bbox[2], self.img_width))
        bbox[3] = max(0, min(bbox[3], self.img_height))
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
# ============================================================================
|
| 750 |
+
# MAIN NORMALIZATION ENGINE
|
| 751 |
+
# ============================================================================
|
| 752 |
+
class LayoutNormalizer:
    """Main engine for normalizing wireframe layout.

    Pipeline: detect alignment groups -> cluster element sizes per type ->
    normalize each element's size and snap it to the grid -> re-align grouped
    elements (mutating bboxes in place) -> resolve remaining overlaps.
    """

    def __init__(self, elements: List[Element], img_width: float, img_height: float):
        self.elements = elements
        self.img_width = img_width
        self.img_height = img_height
        self.grid = GridLayoutSystem(img_width, img_height)
        self.alignment_detector = AlignmentDetector(elements)
        self.size_normalizer = SizeNormalizer(elements, img_width, img_height)

    def normalize_layout(self) -> List[NormalizedElement]:
        """Normalize all elements with proper sizing and alignment.

        Returns:
            NormalizedElement list (possibly shorter than the input if the
            overlap resolver removed near-duplicate detections).
        """
        print("\nπ§ Starting layout normalization...")

        # Step 1: Detect alignments
        h_alignments = self.alignment_detector.detect_horizontal_alignments()
        v_alignments = self.alignment_detector.detect_vertical_alignments()
        edge_alignments = self.alignment_detector.detect_edge_alignments()

        print(f"β Found {len(h_alignments)} horizontal alignment groups")
        print(f"β Found {len(v_alignments)} vertical alignment groups")

        # Step 2: Cluster sizes by type
        size_clusters = self.size_normalizer.cluster_sizes_by_type()
        print(f"β Created size clusters for {len(size_clusters)} element types")

        # Step 3: Create element-to-cluster mapping.
        # Keyed by id() so Element instances need not be hashable/comparable.
        element_to_cluster = {}
        element_to_size_category = {}
        for label, clusters in size_clusters.items():
            for i, cluster in enumerate(clusters):
                category = f"{label}_size_{i + 1}"
                for elem in cluster:
                    element_to_cluster[id(elem)] = cluster
                    element_to_size_category[id(elem)] = category

        # Step 4: Normalize each element
        normalized_elements = []

        for elem in self.elements:
            # Get size cluster (fall back to a singleton cluster).
            cluster = element_to_cluster.get(id(elem), [elem])
            size_category = element_to_size_category.get(id(elem), f"{elem.label}_default")

            # Get normalized size
            norm_width, norm_height = self.size_normalizer.get_normalized_size(elem, cluster)

            # Create normalized bbox (centered on original)
            center_x, center_y = elem.center_x, elem.center_y
            norm_bbox = [
                center_x - norm_width / 2,
                center_y - norm_height / 2,
                center_x + norm_width / 2,
                center_y + norm_height / 2
            ]

            # Snap to grid - preserve original size better
            snapped_bbox = self.grid.snap_to_grid(norm_bbox, elem.label, preserve_size=True)
            grid_position = self.grid.get_grid_position(snapped_bbox)

            normalized_elements.append(NormalizedElement(
                original=elem,
                normalized_bbox=snapped_bbox,
                grid_position=grid_position,
                size_category=size_category
            ))

        # Step 5: Apply alignment corrections
        normalized_elements = self._apply_alignment_corrections(
            normalized_elements, h_alignments, v_alignments, edge_alignments
        )

        # Step 6: Resolve overlaps
        overlap_resolver = OverlapResolver(self.elements, self.img_width, self.img_height)
        normalized_elements = overlap_resolver.resolve_overlaps(normalized_elements)

        print(f"β Normalized {len(normalized_elements)} elements")
        return normalized_elements

    def _apply_alignment_corrections(self, normalized_elements: List[NormalizedElement],
                                     h_alignments: List[List[Element]],
                                     v_alignments: List[List[Element]],
                                     edge_alignments: Dict) -> List[NormalizedElement]:
        """Apply alignment corrections to normalized elements.

        Each alignment group is re-aligned to the group's average coordinate,
        preserving every element's own width/height. Corrections are applied
        in order (horizontal, then vertical, then edges), so a later pass can
        refine positions set by an earlier one. Bboxes are mutated in place.
        """

        # Create lookup dictionary (original Element id -> NormalizedElement).
        elem_to_normalized = {id(ne.original): ne for ne in normalized_elements}

        # Align horizontally grouped elements
        for h_group in h_alignments:
            # Members removed by earlier stages simply drop out of the group.
            norm_group = [elem_to_normalized[id(e)] for e in h_group if id(e) in elem_to_normalized]
            if len(norm_group) > 1:
                # Align to average Y position (centers), keeping each height.
                avg_y = sum((ne.normalized_bbox[1] + ne.normalized_bbox[3]) / 2 for ne in norm_group) / len(norm_group)
                for ne in norm_group:
                    height = ne.normalized_bbox[3] - ne.normalized_bbox[1]
                    ne.normalized_bbox[1] = avg_y - height / 2
                    ne.normalized_bbox[3] = avg_y + height / 2

        # Align vertically grouped elements
        for v_group in v_alignments:
            norm_group = [elem_to_normalized[id(e)] for e in v_group if id(e) in elem_to_normalized]
            if len(norm_group) > 1:
                # Align to average X position (centers), keeping each width.
                avg_x = sum((ne.normalized_bbox[0] + ne.normalized_bbox[2]) / 2 for ne in norm_group) / len(norm_group)
                for ne in norm_group:
                    width = ne.normalized_bbox[2] - ne.normalized_bbox[0]
                    ne.normalized_bbox[0] = avg_x - width / 2
                    ne.normalized_bbox[2] = avg_x + width / 2

        # Align edges: snap each group's shared edge to the group average,
        # shifting the whole box so its size is preserved.
        for edge_type, groups in edge_alignments.items():
            for edge_group in groups:
                norm_group = [elem_to_normalized[id(e)] for e in edge_group if id(e) in elem_to_normalized]
                if len(norm_group) > 1:
                    if edge_type == 'left':
                        avg_left = sum(ne.normalized_bbox[0] for ne in norm_group) / len(norm_group)
                        for ne in norm_group:
                            width = ne.normalized_bbox[2] - ne.normalized_bbox[0]
                            ne.normalized_bbox[0] = avg_left
                            ne.normalized_bbox[2] = avg_left + width
                    elif edge_type == 'right':
                        avg_right = sum(ne.normalized_bbox[2] for ne in norm_group) / len(norm_group)
                        for ne in norm_group:
                            width = ne.normalized_bbox[2] - ne.normalized_bbox[0]
                            ne.normalized_bbox[2] = avg_right
                            ne.normalized_bbox[0] = avg_right - width
                    elif edge_type == 'top':
                        avg_top = sum(ne.normalized_bbox[1] for ne in norm_group) / len(norm_group)
                        for ne in norm_group:
                            height = ne.normalized_bbox[3] - ne.normalized_bbox[1]
                            ne.normalized_bbox[1] = avg_top
                            ne.normalized_bbox[3] = avg_top + height
                    elif edge_type == 'bottom':
                        avg_bottom = sum(ne.normalized_bbox[3] for ne in norm_group) / len(norm_group)
                        for ne in norm_group:
                            height = ne.normalized_bbox[3] - ne.normalized_bbox[1]
                            ne.normalized_bbox[3] = avg_bottom
                            ne.normalized_bbox[1] = avg_bottom - height

        return normalized_elements
|
| 894 |
+
|
| 895 |
+
|
| 896 |
+
# ============================================================================
|
| 897 |
+
# VISUALIZATION & EXPORT
|
| 898 |
+
# ============================================================================
|
| 899 |
+
def visualize_comparison(pil_img: Image.Image, elements: List[Element],
                         normalized_elements: List[NormalizedElement],
                         grid_system: GridLayoutSystem):
    """Render a side-by-side comparison of raw detections vs. normalized layout.

    Left panel: the original predictions drawn as red boxes over the image.
    Right panel: the grid (faint dotted lines), each normalized box drawn
    thick and colour-coded by class, its original box drawn thin grey dashed
    for before/after comparison, and a compact label with grid coordinates.
    Displays the figure via plt.show(); returns nothing.
    """
    fig, (ax_orig, ax_norm) = plt.subplots(1, 2, figsize=(24, 12))

    # --- Left panel: raw model output ------------------------------------
    ax_orig.imshow(pil_img)
    ax_orig.set_title("Original Predictions", fontsize=16, weight='bold')
    ax_orig.axis('off')

    for det in elements:
        bx1, by1, bx2, by2 = det.bbox
        ax_orig.add_patch(patches.Rectangle(
            (bx1, by1), bx2 - bx1, by2 - by1,
            linewidth=2, edgecolor='red', facecolor='none'))
        ax_orig.text(bx1, by1 - 5, det.label, color='red', fontsize=8,
                     bbox=dict(facecolor='white', alpha=0.7))

    # --- Right panel: normalized layout over the grid ---------------------
    ax_norm.imshow(pil_img)
    ax_norm.set_title("Normalized & Aligned Layout", fontsize=16, weight='bold')
    ax_norm.axis('off')

    # Faint dotted grid lines at every cell boundary.
    for col in range(grid_system.num_columns + 1):
        ax_norm.axvline(x=col * grid_system.cell_width, color='blue',
                        linestyle=':', linewidth=0.5, alpha=0.3)
    for row in range(grid_system.num_rows + 1):
        ax_norm.axhline(y=row * grid_system.cell_height, color='blue',
                        linestyle=':', linewidth=0.5, alpha=0.3)

    # One distinct colour per class name (Set3 palette, deterministic).
    np.random.seed(42)
    palette = plt.cm.Set3(np.linspace(0, 1, len(CLASS_NAMES)))
    color_map = dict(zip(CLASS_NAMES, palette))

    for ne in normalized_elements:
        nx1, ny1, nx2, ny2 = ne.normalized_bbox
        cls_color = color_map[ne.original.label]

        # Normalized box: thick, class-coloured.
        ax_norm.add_patch(patches.Rectangle(
            (nx1, ny1), nx2 - nx1, ny2 - ny1,
            linewidth=3, edgecolor=cls_color, facecolor='none'))

        # Original box: thin grey dashed, for before/after comparison.
        ox1, oy1, ox2, oy2 = ne.original.bbox
        ax_norm.add_patch(patches.Rectangle(
            (ox1, oy1), ox2 - ox1, oy2 - oy1,
            linewidth=1, edgecolor='gray', facecolor='none',
            linestyle='--', alpha=0.5))

        # Compact label: class, size category and grid coordinates.
        gp = ne.grid_position
        caption = f"{ne.original.label}\n{ne.size_category}\nR{gp['start_row']} C{gp['start_col']}"
        ax_norm.text(nx1 + 5, ny1 + 15, caption, color='white', fontsize=7,
                     bbox=dict(facecolor=cls_color, alpha=0.8, pad=2))

    plt.tight_layout()
    plt.show()
|
| 967 |
+
|
| 968 |
+
|
| 969 |
+
def export_to_json(normalized_elements: "List[NormalizedElement]",
                   grid_system: "GridLayoutSystem",
                   output_path: str):
    """Serialize the normalized layout (plus grid metadata) to a JSON file.

    Each element record carries: type, confidence, size category, the
    original and normalized pixel bboxes, its grid position, and the
    normalized corners expressed as percentages of the image dimensions.
    Parent directories of *output_path* are created when missing.
    """
    records = []
    for idx, ne in enumerate(normalized_elements):
        src = ne.original
        nb = ne.normalized_bbox

        records.append({
            'id': idx,
            'type': src.label,
            'confidence': round(src.score, 3),
            'size_category': ne.size_category,
            'original_bbox': {
                'x1': round(src.bbox[0], 2),
                'y1': round(src.bbox[1], 2),
                'x2': round(src.bbox[2], 2),
                'y2': round(src.bbox[3], 2),
                'width': round(src.width, 2),
                'height': round(src.height, 2)
            },
            'normalized_bbox': {
                'x1': round(nb[0], 2),
                'y1': round(nb[1], 2),
                'x2': round(nb[2], 2),
                'y2': round(nb[3], 2),
                'width': round(nb[2] - nb[0], 2),
                'height': round(nb[3] - nb[1], 2)
            },
            'grid_position': ne.grid_position,
            'percentage': {
                'x1': round(nb[0] / grid_system.img_width * 100, 2),
                'y1': round(nb[1] / grid_system.img_height * 100, 2),
                'x2': round(nb[2] / grid_system.img_width * 100, 2),
                'y2': round(nb[3] / grid_system.img_height * 100, 2)
            }
        })

    output = {
        'metadata': {
            'image_width': grid_system.img_width,
            'image_height': grid_system.img_height,
            'grid_system': {
                'columns': grid_system.num_columns,
                'rows': grid_system.num_rows,
                'cell_width': round(grid_system.cell_width, 2),
                'cell_height': round(grid_system.cell_height, 2)
            },
            'total_elements': len(normalized_elements)
        },
        'elements': records
    }

    # Create the target directory only when the path actually includes one.
    parent = os.path.dirname(output_path)
    os.makedirs(parent if parent else '.', exist_ok=True)
    with open(output_path, 'w') as f:
        json.dump(output, f, indent=2)

    print(f"\nβ Exported normalized layout to: {output_path}")
|
| 1030 |
+
|
| 1031 |
+
|
| 1032 |
+
def export_to_html(normalized_elements: List[NormalizedElement],
                   grid_system: GridLayoutSystem,
                   output_path: str):
    """Export normalized layout as responsive HTML/CSS.

    Emits a single standalone HTML page: each element becomes an absolutely
    positioned <div> (CSS-classed by element type) inside a container sized
    to the source image, plus a fixed info panel summarising the grid.

    Args:
        normalized_elements: Elements with final pixel-space bboxes.
        grid_system: Supplies image dimensions and grid stats for the page.
        output_path: Destination .html file (parent dirs created as needed).
    """

    # NOTE: literal CSS braces are doubled ({{ }}) because the whole template
    # is passed through str.format() below; single-brace fields
    # ({img_width}, {elements_html}, ...) are the format placeholders.
    html_template = """<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Wireframe Layout</title>
    <style>
        * {{
            margin: 0;
            padding: 0;
            box-sizing: border-box;
        }}

        body {{
            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Arial, sans-serif;
            background: #f5f5f5;
            padding: 20px;
        }}

        .container {{
            max-width: {img_width}px;
            margin: 0 auto;
            background: white;
            position: relative;
            height: {img_height}px;
            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
        }}

        .element {{
            position: absolute;
            border: 2px solid #333;
            display: flex;
            align-items: center;
            justify-content: center;
            font-size: 12px;
            color: #666;
            background: rgba(255,255,255,0.9);
            transition: all 0.3s ease;
        }}

        .element:hover {{
            z-index: 100;
            box-shadow: 0 4px 12px rgba(0,0,0,0.2);
            transform: scale(1.02);
        }}

        .element-label {{
            font-weight: bold;
            font-size: 10px;
            text-transform: uppercase;
        }}

        /* Element type specific styles */
        .button {{
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            border-radius: 6px;
            font-weight: bold;
            cursor: pointer;
        }}

        .checkbox {{
            background: white;
            border: 2px solid #4a5568;
            border-radius: 4px;
        }}

        .textfield {{
            background: white;
            border: 2px solid #cbd5e0;
            border-radius: 4px;
            padding: 8px;
        }}

        .text {{
            background: transparent;
            border: 1px dashed #cbd5e0;
            color: #2d3748;
        }}

        .paragraph {{
            background: transparent;
            border: 1px dashed #cbd5e0;
            color: #4a5568;
            text-align: left;
            padding: 8px;
        }}

        .image {{
            background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
            color: white;
            border: none;
        }}

        .navbar {{
            background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%);
            color: white;
            font-weight: bold;
            border: none;
        }}

        .info-panel {{
            position: fixed;
            top: 20px;
            right: 20px;
            background: white;
            padding: 20px;
            border-radius: 8px;
            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
            max-width: 300px;
        }}

        .info-panel h3 {{
            margin-bottom: 10px;
            color: #2d3748;
        }}

        .info-panel p {{
            margin: 5px 0;
            font-size: 14px;
            color: #4a5568;
        }}
    </style>
</head>
<body>
    <div class="info-panel">
        <h3>π Layout Info</h3>
        <p><strong>Grid:</strong> {grid_cols} Γ {grid_rows}</p>
        <p><strong>Elements:</strong> {total_elements}</p>
        <p><strong>Dimensions:</strong> {img_width}px Γ {img_height}px</p>
        <p style="margin-top: 15px; font-size: 12px; color: #718096;">
            Hover over elements to see details
        </p>
    </div>

    <div class="container">
        {elements_html}
    </div>
</body>
</html>"""

    # Build one absolutely positioned <div> per element; the CSS class is the
    # element's label so the type-specific styles above apply. The title
    # attribute doubles as a hover tooltip with grid/size details.
    elements_html = []

    for i, norm_elem in enumerate(normalized_elements):
        x1, y1, x2, y2 = norm_elem.normalized_bbox
        width = x2 - x1
        height = y2 - y1

        element_html = f"""
    <div class="element {norm_elem.original.label}"
         style="left: {x1}px; top: {y1}px; width: {width}px; height: {height}px;"
         title="{norm_elem.original.label} | Grid: R{norm_elem.grid_position['start_row']} C{norm_elem.grid_position['start_col']} | Size: {norm_elem.size_category}">
        <span class="element-label">{norm_elem.original.label}</span>
    </div>"""

        elements_html.append(element_html)

    # Fill the template; container is sized in whole pixels.
    html_content = html_template.format(
        img_width=int(grid_system.img_width),
        img_height=int(grid_system.img_height),
        grid_cols=grid_system.num_columns,
        grid_rows=grid_system.num_rows,
        total_elements=len(normalized_elements),
        elements_html='\n'.join(elements_html)
    )

    # Create the target directory only when the path actually includes one.
    os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else '.', exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(html_content)

    print(f"β Exported HTML layout to: {output_path}")
|
| 1208 |
+
|
| 1209 |
+
|
| 1210 |
+
# ============================================================================
# MAIN PIPELINE
# ============================================================================
def process_wireframe(image_path: str,
                      save_json: bool = True,
                      save_html: bool = True,
                      show_visualization: bool = True) -> Dict:
    """
    Complete pipeline to process wireframe image.

    Lazily loads the detection model on first call (cached in the
    module-level ``model`` global), detects UI elements, normalizes the
    layout onto a grid, and optionally exports JSON/HTML plus a
    matplotlib comparison.

    Args:
        image_path: Path to wireframe image
        save_json: Export normalized layout as JSON
        save_html: Export normalized layout as HTML
        show_visualization: Display matplotlib comparison

    Returns:
        Dictionary containing all processing results ('image',
        'original_elements', 'normalized_elements', 'grid_system', and
        'json_path' / 'html_path' when the corresponding export ran).
        Returns an empty dict when the model cannot be loaded or no
        elements are detected.
    """

    print("=" * 80)
    print("🚀 WIREFRAME LAYOUT NORMALIZER")
    print("=" * 80)

    # Step 1: Load model and get predictions
    # The model is cached in a module-level global so repeated calls
    # (e.g. batch processing) reuse the same loaded instance.
    global model
    if model is None:
        print("\n📦 Loading model...")
        try:
            model = tf.keras.models.load_model(
                MODEL_PATH,
                custom_objects={'LossCalculation': LossCalculation}
            )
            print("✅ Model loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            print("\nTrying alternative loading method...")
            try:
                # Fallback: skip compilation so the custom loss object is
                # not required for inference-only use.
                model = tf.keras.models.load_model(MODEL_PATH, compile=False)
                print("✅ Model loaded successfully (without compilation)!")
            except Exception as e2:
                print(f"❌ Failed to load model: {e2}")
                return {}

    print(f"\n📸 Processing image: {image_path}")
    pil_img, elements = get_predictions(image_path)
    print(f"✅ Detected {len(elements)} elements")

    if not elements:
        print("⚠️ No elements detected. Exiting.")
        return {}

    # Step 2: Normalize layout onto the grid system
    normalizer = LayoutNormalizer(elements, pil_img.width, pil_img.height)
    normalized_elements = normalizer.normalize_layout()

    # Step 3: Generate outputs
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    base_filename = os.path.splitext(os.path.basename(image_path))[0]

    results = {
        'image': pil_img,
        'original_elements': elements,
        'normalized_elements': normalized_elements,
        'grid_system': normalizer.grid
    }

    # Export JSON
    if save_json:
        json_path = os.path.join(OUTPUT_DIR, f"{base_filename}_normalized.json")
        export_to_json(normalized_elements, normalizer.grid, json_path)
        results['json_path'] = json_path

    # Export HTML
    if save_html:
        html_path = os.path.join(OUTPUT_DIR, f"{base_filename}_layout.html")
        export_to_html(normalized_elements, normalizer.grid, html_path)
        results['html_path'] = html_path

    # Visualize (side-by-side original vs normalized comparison;
    # NOTE(review): may block in a headless/server environment — confirm
    # matplotlib backend before enabling this on the Space)
    if show_visualization:
        print("\n🎨 Generating visualization...")
        visualize_comparison(pil_img, elements, normalized_elements, normalizer.grid)

    # Print summary
    print("\n" + "=" * 80)
    print("📊 PROCESSING SUMMARY")
    print("=" * 80)

    # Count detected elements by label/type
    type_counts = {}
    for elem in elements:
        type_counts[elem.label] = type_counts.get(elem.label, 0) + 1

    print(f"\n📦 Element Types:")
    for elem_type, count in sorted(type_counts.items()):
        print(f"   • {elem_type}: {count}")

    # Distinct size categories assigned by the normalizer
    size_categories = {}
    for norm_elem in normalized_elements:
        size_categories[norm_elem.size_category] = size_categories.get(norm_elem.size_category, 0) + 1

    print(f"\n📏 Size Categories: {len(size_categories)}")

    # Alignment info (groups found by the normalizer's alignment detector)
    h_alignments = normalizer.alignment_detector.detect_horizontal_alignments()
    v_alignments = normalizer.alignment_detector.detect_vertical_alignments()

    print(f"\n📐 Alignment:")
    print(f"   • Horizontal groups: {len(h_alignments)}")
    print(f"   • Vertical groups: {len(v_alignments)}")

    print("\n" + "=" * 80)
    print("✅ PROCESSING COMPLETE!")
    print("=" * 80 + "\n")

    return results
|
| 1328 |
+
|
| 1329 |
+
|
| 1330 |
+
def batch_process(image_dir: str, pattern: str = "*.png"):
    """
    Process multiple wireframe images in a directory.

    Runs :func:`process_wireframe` on every image matching *pattern*
    (visualization disabled) and prints a success/failure summary.

    Args:
        image_dir: Directory containing wireframe images.
        pattern: Glob pattern used to select images (default ``"*.png"``).

    Returns:
        List of per-image dicts with keys ``'image_path'``, ``'success'``,
        and either ``'results'`` or ``'error'``. Returns ``None`` when no
        image matches the pattern.
    """
    import glob

    image_paths = glob.glob(os.path.join(image_dir, pattern))

    if not image_paths:
        print(f"❌ No images found matching pattern: {pattern}")
        return

    print(f"🔍 Found {len(image_paths)} images to process\n")

    all_results = []
    for i, image_path in enumerate(image_paths, 1):
        print(f"\n{'=' * 80}")
        print(f"Processing image {i}/{len(image_paths)}: {os.path.basename(image_path)}")
        print(f"{'=' * 80}")

        try:
            results = process_wireframe(
                image_path,
                save_json=True,
                save_html=True,
                show_visualization=False
            )
            all_results.append({
                'image_path': image_path,
                # BUGFIX: process_wireframe returns {} on failure (model
                # load error or no detected elements). Previously this was
                # hard-coded True, so the batch summary over-reported
                # successes; count an empty result as a failure instead.
                'success': bool(results),
                'results': results
            })
        except Exception as e:
            print(f"❌ Error processing {image_path}: {str(e)}")
            all_results.append({
                'image_path': image_path,
                'success': False,
                'error': str(e)
            })

    # Summary
    successful = sum(1 for r in all_results if r['success'])
    print(f"\n{'=' * 80}")
    print(f"📊 BATCH PROCESSING COMPLETE")
    print(f"{'=' * 80}")
    print(f"✅ Successful: {successful}/{len(image_paths)}")
    print(f"❌ Failed: {len(image_paths) - successful}/{len(image_paths)}")

    return all_results
|
| 1377 |
+
|
| 1378 |
+
|
| 1379 |
+
# ============================================================================
# EXAMPLE USAGE
# ============================================================================
if __name__ == "__main__":
    # Single image processing
    image_path = "./image/6LHls1vE.jpg"

    # Process with all outputs (JSON + HTML export and matplotlib preview)
    results = process_wireframe(
        image_path,
        save_json=True,
        save_html=True,
        show_visualization=True
    )

    # Or batch process multiple images
    # batch_results = batch_process("./wireframes/", pattern="*.png")
|