ChristianQ committed on
Commit
f7a075a
Β·
1 Parent(s): a9565d5

Added app.py, detection module, and model

Browse files
Files changed (2) hide show
  1. app.py +121 -0
  2. visualization.py +1395 -0
app.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ import shutil
4
+ import traceback
5
+
6
+ import gradio as gr
7
+ from fastapi import FastAPI, File, UploadFile
8
+ from fastapi.responses import JSONResponse
9
+
10
+ from visualization import process_wireframe
11
+
12
+ # -----------------------------------------------------------------------------
13
+ # FASTAPI (for Firebase / programmatic access)
14
+ # -----------------------------------------------------------------------------
15
+ api = FastAPI()
16
+
17
+ TEMP_DIR = "./temp"
18
+ OUTPUT_DIR = "./output"
19
+
20
+ os.makedirs(TEMP_DIR, exist_ok=True)
21
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
22
+
23
+
24
@api.get("/")
def health_check():
    """Liveness probe: report that the API process is up and serving."""
    payload = {"status": "ok"}
    return payload
27
+
28
+
29
@api.post("/process-wireframe")
async def process_wireframe_api(image: UploadFile = File(...)):
    """Run the wireframe-detection pipeline on an uploaded image.

    The upload is spooled to a unique temp file, handed to
    ``process_wireframe``, and the temp file is always removed afterwards.

    Returns:
        200 with json/html output paths and element count on success,
        400 when no elements were detected, 500 on any pipeline error.
    """
    # SECURITY: the client controls ``image.filename``.  basename() strips
    # any directory components (e.g. "../../evil.png") so the temp file
    # cannot be written outside TEMP_DIR.
    safe_name = os.path.basename(image.filename or "upload")
    file_id = str(uuid.uuid4())
    temp_path = os.path.join(TEMP_DIR, f"{file_id}_{safe_name}")

    try:
        # Stream the upload to disk without loading it fully into memory.
        with open(temp_path, "wb") as f:
            shutil.copyfileobj(image.file, f)

        results = process_wireframe(
            image_path=temp_path,
            save_json=True,
            save_html=True,
            show_visualization=False
        )

        if not results:
            return JSONResponse(
                status_code=400,
                content={"error": "No elements detected"}
            )

        return {
            "success": True,
            "json_path": results.get("json_path"),
            "html_path": results.get("html_path"),
            "total_elements": len(results["normalized_elements"])
        }

    except Exception as e:
        # Log the full traceback server-side; return only the message.
        traceback.print_exc()
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )

    finally:
        # Best-effort cleanup of the spooled upload.
        if os.path.exists(temp_path):
            os.remove(temp_path)
68
+
69
+
70
+ # -----------------------------------------------------------------------------
71
+ # GRADIO (for Hugging Face UI)
72
+ # -----------------------------------------------------------------------------
73
def gradio_process(image):
    """Gradio entry point.

    Receives a PIL image from the UI, persists it to a throwaway PNG and
    feeds it through the exact same pipeline the FastAPI endpoint uses.
    Returns a (status message, json file path) pair for the two outputs.
    """
    temp_path = f"{TEMP_DIR}/{uuid.uuid4()}.png"
    image.save(temp_path)

    try:
        results = process_wireframe(
            image_path=temp_path,
            save_json=True,
            save_html=True,
            show_visualization=False
        )

        if results:
            status = f"Detected {len(results['normalized_elements'])} elements"
            return status, results.get("json_path")
        return "No elements detected", None

    except Exception as exc:
        traceback.print_exc()
        return f"Error: {str(exc)}", None

    finally:
        # Always drop the temp PNG, success or failure.
        if os.path.exists(temp_path):
            os.remove(temp_path)
104
+
105
+
106
# Gradio UI definition: one image input, two outputs (status text + the
# normalized-JSON file produced by gradio_process).
demo = gr.Interface(
    fn=gradio_process,
    inputs=gr.Image(type="pil", label="Upload Wireframe"),
    outputs=[
        gr.Textbox(label="Status"),
        gr.File(label="Normalized JSON Output")
    ],
    title="Wireframe Layout Normalizer",
    description="Upload a wireframe image to extract and normalize UI layout"
)

# -----------------------------------------------------------------------------
# ENTRY POINT (THIS IS IMPORTANT)
# -----------------------------------------------------------------------------
# Mount the Gradio UI onto the FastAPI instance at "/" and expose the combined
# ASGI application as `app` — this is the name the ASGI server is expected to
# serve, so both the REST endpoints and the UI share one process/port.
app = gr.mount_gradio_app(api, demo, path="/")
121
+
visualization.py ADDED
@@ -0,0 +1,1395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import matplotlib.patches as patches
5
+ from PIL import Image, ImageOps
6
+ import json
7
+ import os
8
+ from collections import defaultdict
9
+ from dataclasses import dataclass
10
+ from typing import List, Tuple, Dict, Optional
11
+
12
+
13
+ # ============================================================================
14
+ # CUSTOM LOSS CLASS (Required for model loading)
15
+ # ============================================================================
16
@tf.keras.utils.register_keras_serializable()
class LossCalculation(tf.keras.losses.Loss):
    """Custom loss function for wireframe detection.

    YOLO-style single-cell loss over a prediction grid whose last axis is
    [objectness, x, y, w, h, class logits...]:
      * objectness: sigmoid cross-entropy, negatives down-weighted by
        ``lambda_noobj``;
      * box: smooth-L1 on sigmoid-activated xy/wh, only where an object is
        present, scaled by ``lambda_coord``;
      * class: softmax cross-entropy, only where an object is present.
    Registered as Keras-serializable so saved models can be reloaded.
    """

    def __init__(self, num_classes=7, lambda_coord=5.0, lambda_noobj=0.5,
                 name='loss_calculation', reduction='sum_over_batch_size', **kwargs):
        # NOTE(review): extra **kwargs are accepted but silently dropped
        # (not forwarded to super()) — presumably to tolerate legacy config
        # keys at load time; confirm intent.
        super().__init__(name=name, reduction=reduction)
        self.num_classes = num_classes
        self.lambda_coord = lambda_coord
        self.lambda_noobj = lambda_noobj

    def call(self, y_true, y_pred):
        # Split the channel axis: [obj, x, y, w, h, classes...].
        obj_true = y_true[..., 0]
        box_true = y_true[..., 1:5]
        cls_true = y_true[..., 5:]

        # Predictions are raw logits; activations are applied below.
        obj_pred_logits = y_pred[..., 0]
        box_pred = y_pred[..., 1:5]
        cls_pred_logits = y_pred[..., 5:]

        obj_mask = tf.cast(obj_true > 0.5, tf.float32)
        noobj_mask = 1.0 - obj_mask
        # Guard against division by zero when a batch has no positives.
        num_pos = tf.maximum(tf.reduce_sum(obj_mask), 1.0)

        # Objectness: positive cells at full weight, empty cells scaled by
        # lambda_noobj; normalized over the whole grid.
        obj_loss_pos = obj_mask * tf.nn.sigmoid_cross_entropy_with_logits(
            labels=obj_true, logits=obj_pred_logits)
        obj_loss_neg = noobj_mask * tf.nn.sigmoid_cross_entropy_with_logits(
            labels=obj_true, logits=obj_pred_logits)
        obj_loss = (tf.reduce_sum(obj_loss_pos) + self.lambda_noobj * tf.reduce_sum(obj_loss_neg)) / tf.cast(
            tf.size(obj_true), tf.float32)

        # Box regression: sigmoid squashes predictions into [0, 1] to match
        # the normalized targets.
        xy_pred = tf.nn.sigmoid(box_pred[..., 0:2])
        wh_pred = tf.nn.sigmoid(box_pred[..., 2:4])
        xy_true = box_true[..., 0:2]
        wh_true = box_true[..., 2:4]

        # Smooth-L1 on positive cells only, averaged per positive cell.
        xy_loss = tf.reduce_sum(obj_mask[..., tf.newaxis] * self._smooth_l1_loss(xy_true - xy_pred)) / num_pos
        wh_loss = tf.reduce_sum(obj_mask[..., tf.newaxis] * self._smooth_l1_loss(wh_true - wh_pred)) / num_pos
        box_loss = self.lambda_coord * (xy_loss + wh_loss)

        # Classification only where an object exists.
        cls_loss = tf.reduce_sum(obj_mask * tf.nn.softmax_cross_entropy_with_logits(
            labels=cls_true, logits=cls_pred_logits)) / num_pos

        total_loss = obj_loss + box_loss + cls_loss
        # Clip to keep early-training spikes from destabilizing optimization.
        return tf.clip_by_value(total_loss, 0.0, 100.0)

    def _smooth_l1_loss(self, x, beta=1.0):
        """Huber-style loss: quadratic for |x| < beta, linear beyond."""
        abs_x = tf.abs(x)
        return tf.where(abs_x < beta, 0.5 * x * x / beta, abs_x - 0.5 * beta)

    def get_config(self):
        """Serialize hyperparameters so Keras can save/reload the loss."""
        config = super().get_config()
        config.update({
            'num_classes': self.num_classes,
            'lambda_coord': self.lambda_coord,
            'lambda_noobj': self.lambda_noobj,
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the loss from a saved Keras config dict."""
        return cls(**config)
78
+
79
+
80
+ # ============================================================================
81
+ # CONFIGURATION - UPDATED FOR BETTER PRECISION
82
+ # ============================================================================
83
# Path to the trained Keras detection model loaded lazily at runtime.
MODEL_PATH = "./wireframe_detection_model_best_700.keras"
OUTPUT_DIR = "./output/"
# Class index order must match the model's training labels.
CLASS_NAMES = ["button", "checkbox", "image", "navbar", "paragraph", "text", "textfield"]

# Model input resolution (square) and detection thresholds.
IMG_SIZE = 416
CONF_THRESHOLD = 0.1   # minimum objectness/class confidence to keep a box
IOU_THRESHOLD = 0.1    # NMS suppression threshold (per class)

# Layout Configuration - INCREASED GRID DENSITY
GRID_COLUMNS = 24  # Doubled from 12 for finer precision
ALIGNMENT_THRESHOLD = 10  # Reduced from 15 for tighter alignment (pixels)
SIZE_CLUSTERING_THRESHOLD = 15  # Reduced from 20 for better size grouping (pixels)

# Standard sizes for each element type (relative grid units) - UPDATED FOR
# SMALLER BUTTONS/CHECKBOXES.  Used by GridLayoutSystem.snap_to_grid when an
# element's span is close to the type's conventional footprint.
STANDARD_SIZES = {
    'button': {'width': 2, 'height': 1},  # Smaller button (was 2x1, now in finer grid)
    'checkbox': {'width': 1, 'height': 1},  # Keep small checkbox
    'textfield': {'width': 5, 'height': 1},  # Adjusted for new grid
    'text': {'width': 3, 'height': 1},  # Adjusted
    'paragraph': {'width': 8, 'height': 2},  # Adjusted
    'image': {'width': 4, 'height': 4},  # Adjusted
    'navbar': {'width': 24, 'height': 1}  # Full width in new grid
}

# Global model handle; populated elsewhere before get_predictions() is called.
model = None
108
+
109
+
110
+ # ============================================================================
111
+ # DATA STRUCTURES
112
+ # ============================================================================
113
@dataclass
class Element:
    """A detected UI element plus geometry derived from its bounding box."""
    label: str
    score: float
    bbox: List[float]  # [x1, y1, x2, y2]
    width: float = 0
    height: float = 0
    center_x: float = 0
    center_y: float = 0

    def __post_init__(self):
        # Derive size and center once, right after construction.
        x1, y1, x2, y2 = self.bbox
        self.width = x2 - x1
        self.height = y2 - y1
        self.center_x = (x1 + x2) / 2
        self.center_y = (y1 + y2) / 2
129
+
130
+
131
@dataclass
class NormalizedElement:
    """A UI element after layout normalization (grid snap / size cluster)."""
    original: Element              # the raw detection this entry was derived from
    normalized_bbox: List[float]   # adjusted [x1, y1, x2, y2] in pixels
    grid_position: Dict            # grid placement (start/end row & col, spans)
    size_category: str
    alignment_group: Optional[int] = None  # index of its alignment group, if any
139
+
140
+
141
+ # ============================================================================
142
+ # PREDICTION EXTRACTION
143
+ # ============================================================================
144
def get_predictions(image_path: str) -> Tuple[Image.Image, List[Element]]:
    """Extract predictions from the model.

    Loads the image, runs the global ``model`` on a 416x416 resize, decodes
    the YOLO-style output grid into pixel-space boxes (scaled back to the
    original image size), and applies per-class NMS.

    Returns:
        (original PIL image, list of Element detections).
    Raises:
        ValueError: if the global ``model`` has not been loaded yet.
    """
    global model
    if model is None:
        raise ValueError("Model not loaded. Please load the model first.")

    # Load and preprocess image; exif_transpose honors camera orientation.
    pil_img = Image.open(image_path).convert("RGB")
    pil_img = ImageOps.exif_transpose(pil_img)
    orig_w, orig_h = pil_img.size
    resized_img = pil_img.resize((IMG_SIZE, IMG_SIZE), Image.LANCZOS)
    img_array = np.array(resized_img, dtype=np.float32) / 255.0
    input_tensor = np.expand_dims(img_array, axis=0)

    # Get predictions: an S x S grid; channel 0 = objectness logit,
    # 1-4 = box logits, 5+ = class logits.
    pred_grid = model.predict(input_tensor, verbose=0)[0]
    raw_boxes = []
    S = pred_grid.shape[0]
    cell_size = 1.0 / S

    for row in range(S):
        for col in range(S):
            # Sigmoid converts the raw logit to a probability; cheap early
            # exit before decoding the rest of the cell.
            obj_score = float(tf.nn.sigmoid(pred_grid[row, col, 0]))
            if obj_score < CONF_THRESHOLD:
                continue

            # Box channels are sigmoid-activated into [0, 1] (matches the
            # training loss in LossCalculation.call).
            x_offset = float(tf.nn.sigmoid(pred_grid[row, col, 1]))
            y_offset = float(tf.nn.sigmoid(pred_grid[row, col, 2]))
            width = float(tf.nn.sigmoid(pred_grid[row, col, 3]))
            height = float(tf.nn.sigmoid(pred_grid[row, col, 4]))

            class_logits = pred_grid[row, col, 5:]
            class_probs = tf.nn.softmax(class_logits).numpy()
            class_id = int(np.argmax(class_probs))
            class_conf = float(class_probs[class_id])
            # Final score combines objectness and class confidence.
            final_score = obj_score * class_conf

            if final_score < CONF_THRESHOLD:
                continue

            # Cell-relative center -> normalized image coords -> pixels in
            # the ORIGINAL image resolution.
            center_x = (col + x_offset) * cell_size
            center_y = (row + y_offset) * cell_size
            x1 = (center_x - width / 2) * orig_w
            y1 = (center_y - height / 2) * orig_h
            x2 = (center_x + width / 2) * orig_w
            y2 = (center_y + height / 2) * orig_h

            # Discard degenerate (zero/negative area) boxes.
            if x2 > x1 and y2 > y1:
                raw_boxes.append((class_id, final_score, x1, y1, x2, y2))

    # Apply NMS per class so different element types never suppress each other.
    elements = []
    for class_id in range(len(CLASS_NAMES)):
        class_boxes = [(score, x1, y1, x2, y2) for cid, score, x1, y1, x2, y2 in raw_boxes if cid == class_id]
        if not class_boxes:
            continue

        scores = [b[0] for b in class_boxes]
        boxes_xyxy = [[b[1], b[2], b[3], b[4]] for b in class_boxes]

        selected_indices = tf.image.non_max_suppression(
            boxes=boxes_xyxy,
            scores=scores,
            max_output_size=50,
            iou_threshold=IOU_THRESHOLD,
            score_threshold=CONF_THRESHOLD
        )

        for idx in selected_indices.numpy():
            score, x1, y1, x2, y2 = class_boxes[idx]
            elements.append(Element(
                label=CLASS_NAMES[class_id],
                score=float(score),
                bbox=[float(x1), float(y1), float(x2), float(y2)]
            ))

    return pil_img, elements
221
+
222
+
223
+ # ============================================================================
224
+ # ALIGNMENT DETECTION
225
+ # ============================================================================
226
class AlignmentDetector:
    """Detects alignment relationships between elements.

    All methods use a greedy one-pass sweep: elements are sorted along one
    coordinate and grouped while each newcomer stays within ``threshold``
    pixels of the current group's running average.  Only groups with two or
    more members are reported.
    """

    def __init__(self, elements: List[Element], threshold: float = ALIGNMENT_THRESHOLD):
        self.elements = elements
        self.threshold = threshold

    def detect_horizontal_alignments(self) -> List[List[Element]]:
        """Group elements that are horizontally aligned (same Y position)."""
        return self._cluster_by_center(horizontal=True)

    def detect_vertical_alignments(self) -> List[List[Element]]:
        """Group elements that are vertically aligned (same X position)."""
        return self._cluster_by_center(horizontal=False)

    def _cluster_by_center(self, horizontal: bool) -> List[List[Element]]:
        """Shared sweep for both center alignments.

        Sorts by one center coordinate, grows groups whose running mean stays
        within the threshold, and orders each emitted group by the other
        coordinate (left-to-right for rows, top-to-bottom for columns).
        """
        if not self.elements:
            return []

        if horizontal:
            primary = lambda e: e.center_y    # cluster on Y ...
            secondary = lambda e: e.center_x  # ... order rows left-to-right
        else:
            primary = lambda e: e.center_x
            secondary = lambda e: e.center_y

        sorted_elements = sorted(self.elements, key=primary)
        groups = []
        current_group = [sorted_elements[0]]

        for elem in sorted_elements[1:]:
            # Compare against the group's running average, not its seed, so
            # the cluster center can drift slightly with each member.
            avg = sum(primary(e) for e in current_group) / len(current_group)
            if abs(primary(elem) - avg) <= self.threshold:
                current_group.append(elem)
            else:
                if len(current_group) > 1:
                    current_group.sort(key=secondary)
                    groups.append(current_group)
                current_group = [elem]

        if len(current_group) > 1:
            current_group.sort(key=secondary)
            groups.append(current_group)

        return groups

    def detect_edge_alignments(self) -> Dict[str, List[List[Element]]]:
        """Detect elements with aligned edges (left, right, top, bottom)."""
        alignments = {
            'left': [],
            'right': [],
            'top': [],
            'bottom': []
        }

        if not self.elements:
            return alignments

        sorted_left = sorted(self.elements, key=lambda e: e.bbox[0])
        alignments['left'] = self._cluster_by_value(sorted_left, lambda e: e.bbox[0])

        sorted_right = sorted(self.elements, key=lambda e: e.bbox[2])
        alignments['right'] = self._cluster_by_value(sorted_right, lambda e: e.bbox[2])

        sorted_top = sorted(self.elements, key=lambda e: e.bbox[1])
        alignments['top'] = self._cluster_by_value(sorted_top, lambda e: e.bbox[1])

        sorted_bottom = sorted(self.elements, key=lambda e: e.bbox[3])
        alignments['bottom'] = self._cluster_by_value(sorted_bottom, lambda e: e.bbox[3])

        return alignments

    def _cluster_by_value(self, elements: List[Element], value_func) -> List[List[Element]]:
        """Cluster pre-sorted elements by a value function within threshold.

        Maintains a running average of the current cluster's value so the
        acceptance window tracks the cluster, not just its first member.
        """
        if not elements:
            return []

        groups = []
        current_group = [elements[0]]
        current_value = value_func(elements[0])

        for elem in elements[1:]:
            elem_value = value_func(elem)
            if abs(elem_value - current_value) <= self.threshold:
                current_group.append(elem)
                # Incremental running mean of the cluster value.
                current_value = (current_value * (len(current_group) - 1) + elem_value) / len(current_group)
            else:
                if len(current_group) > 1:
                    groups.append(current_group)
                current_group = [elem]
                current_value = elem_value

        if len(current_group) > 1:
            groups.append(current_group)

        return groups
333
+
334
+
335
+ # ============================================================================
336
+ # SIZE NORMALIZATION - UPDATED TO RESPECT ACTUAL SIZES MORE
337
+ # ============================================================================
338
class SizeNormalizer:
    """Normalizes element sizes based on type and clustering.

    Elements of the same class are clustered by width, then each width
    cluster is sub-clustered by height.  Normalization is deliberately
    conservative: an element only inherits its cluster's median size when
    the cluster is large enough and the element is already close to it.
    """

    def __init__(self, elements: List[Element], img_width: float, img_height: float):
        self.elements = elements
        self.img_width = img_width
        self.img_height = img_height
        self.size_clusters = {}  # populated by callers as needed

    def cluster_sizes_by_type(self) -> Dict[str, List[List[Element]]]:
        """Cluster elements of same type by similar sizes.

        Returns a mapping of class name -> list of clusters, where each
        cluster groups elements similar in BOTH width and height.
        """
        clusters_by_type = {}

        for label in CLASS_NAMES:
            type_elements = [e for e in self.elements if e.label == label]
            if not type_elements:
                continue

            # Two-stage clustering: width first, then height within each
            # width cluster.
            width_clusters = self._cluster_by_dimension(type_elements, 'width')
            final_clusters = []
            for width_cluster in width_clusters:
                height_clusters = self._cluster_by_dimension(width_cluster, 'height')
                final_clusters.extend(height_clusters)

            clusters_by_type[label] = final_clusters

        return clusters_by_type

    def _cluster_by_dimension(self, elements: List[Element], dimension: str) -> List[List[Element]]:
        """Cluster elements by width or height.

        Greedy sweep over elements sorted by the dimension; a new cluster
        starts whenever an element deviates from the current cluster's
        average by more than SIZE_CLUSTERING_THRESHOLD pixels.
        """
        if not elements:
            return []

        sorted_elements = sorted(elements, key=lambda e: getattr(e, dimension))
        clusters = []
        current_cluster = [sorted_elements[0]]

        for elem in sorted_elements[1:]:
            avg_dim = sum(getattr(e, dimension) for e in current_cluster) / len(current_cluster)
            if abs(getattr(elem, dimension) - avg_dim) <= SIZE_CLUSTERING_THRESHOLD:
                current_cluster.append(elem)
            else:
                clusters.append(current_cluster)
                current_cluster = [elem]

        clusters.append(current_cluster)
        return clusters

    def get_normalized_size(self, element: Element, size_cluster: List[Element]) -> Tuple[float, float]:
        """Get normalized size for an element based on its cluster - PRESERVES ACTUAL SIZE BETTER.

        Returns a (width, height) pair rounded to whole pixels.  The cluster
        median is adopted only when the cluster has 3+ members and the
        element is within 30% of that median; otherwise the element's own
        detected size wins.
        """
        # Use the actual detected size instead of aggressive averaging
        # Only normalize if there's a significant cluster
        if len(size_cluster) >= 3:
            # Use median instead of mean to avoid outliers
            widths = sorted([e.width for e in size_cluster])
            heights = sorted([e.height for e in size_cluster])
            median_width = widths[len(widths) // 2]
            median_height = heights[len(heights) // 2]

            # Only normalize if element is within 30% of median
            if abs(element.width - median_width) / median_width < 0.3:
                normalized_width = round(median_width)
            else:
                normalized_width = round(element.width)

            if abs(element.height - median_height) / median_height < 0.3:
                normalized_height = round(median_height)
            else:
                normalized_height = round(element.height)
        else:
            # Small cluster - keep original size
            normalized_width = round(element.width)
            normalized_height = round(element.height)

        return normalized_width, normalized_height
413
+
414
+
415
+ # ============================================================================
416
+ # GRID-BASED LAYOUT SYSTEM - UPDATED FOR FINER PRECISION
417
+ # ============================================================================
418
class GridLayoutSystem:
    """Grid-based layout system for precise positioning.

    Divides the image into ``num_columns`` columns of equal width; the row
    count is derived so cells are roughly square.  Provides snapping of
    pixel boxes onto the grid and reporting of grid coordinates.
    """

    def __init__(self, img_width: float, img_height: float, num_columns: int = GRID_COLUMNS):
        self.img_width = img_width
        self.img_height = img_height
        self.num_columns = num_columns

        # Rows are sized off the column width so grid cells are ~square.
        cell_width = img_width / num_columns
        self.num_rows = max(1, int(img_height / cell_width))
        # NOTE(review): recomputes the same value as the local above.
        self.cell_width = img_width / num_columns
        self.cell_height = img_height / self.num_rows

        print(f"πŸ“ Grid System: {self.num_columns} columns Γ— {self.num_rows} rows")
        print(f"πŸ“ Cell size: {self.cell_width:.1f}px Γ— {self.cell_height:.1f}px")

    def snap_to_grid(self, bbox: List[float], element_label: str, preserve_size: bool = True) -> List[float]:
        """Snap bounding box to grid - UPDATED TO PRESERVE ORIGINAL SIZE BETTER.

        The box is re-centered on the nearest grid cell and its span is
        rounded to whole cells.  With preserve_size=False the span is nudged
        to the element type's STANDARD_SIZES entry when already close to it.
        Returns the snapped [x1, y1, x2, y2] in pixels.
        """
        x1, y1, x2, y2 = bbox
        original_width = x2 - x1
        original_height = y2 - y1

        # Calculate center
        center_x = (x1 + x2) / 2
        center_y = (y1 + y2) / 2

        # Find nearest grid cell for center
        center_col = round(center_x / self.cell_width)
        center_row = round(center_y / self.cell_height)

        if preserve_size:
            # Calculate span based on actual size (don't force to standard)
            width_cells = max(1, round(original_width / self.cell_width))
            height_cells = max(1, round(original_height / self.cell_height))
        else:
            # Use standard size
            standard = STANDARD_SIZES.get(element_label, {'width': 2, 'height': 1})
            width_cells = max(1, round(original_width / self.cell_width))
            height_cells = max(1, round(original_height / self.cell_height))

            # Only adjust to standard if very close
            # NOTE(review): both sides are ints, so "<= 0.5" is effectively
            # an equality test — confirm whether a wider tolerance was meant.
            if abs(width_cells - standard['width']) <= 0.5:
                width_cells = standard['width']
            if abs(height_cells - standard['height']) <= 0.5:
                height_cells = standard['height']

        # Calculate start position (center the element)
        start_col = center_col - width_cells // 2
        start_row = center_row - height_cells // 2

        # Clamp to grid bounds
        start_col = max(0, min(start_col, self.num_columns - width_cells))
        start_row = max(0, min(start_row, self.num_rows - height_cells))

        # Convert back to pixels
        snapped_x1 = start_col * self.cell_width
        snapped_y1 = start_row * self.cell_height
        snapped_x2 = (start_col + width_cells) * self.cell_width
        snapped_y2 = (start_row + height_cells) * self.cell_height

        return [snapped_x1, snapped_y1, snapped_x2, snapped_y2]

    def get_grid_position(self, bbox: List[float]) -> Dict:
        """Get grid position information for a bounding box.

        Start indices floor, end indices ceil, so the reported span covers
        every cell the box touches.
        """
        x1, y1, x2, y2 = bbox

        start_col = int(x1 / self.cell_width)
        start_row = int(y1 / self.cell_height)
        end_col = int(np.ceil(x2 / self.cell_width))
        end_row = int(np.ceil(y2 / self.cell_height))

        return {
            'start_row': start_row,
            'end_row': end_row,
            'start_col': start_col,
            'end_col': end_col,
            'rowspan': end_row - start_row,
            'colspan': end_col - start_col
        }
497
+
498
+
499
+ # ============================================================================
500
+ # OVERLAP DETECTION & RESOLUTION - UPDATED WITH BETTER STRATEGIES
501
+ # ============================================================================
502
class OverlapResolver:
    """Detects and resolves overlapping elements."""

    def __init__(self, elements: List[Element], img_width: float, img_height: float):
        self.elements = elements
        self.img_width = img_width
        self.img_height = img_height
        # Minimum per-box overlap ratio before a pair is acted upon.
        self.overlap_threshold = 0.2  # Reduced from 0.3 - be more aggressive
510
+
511
+ def compute_iou(self, bbox1: List[float], bbox2: List[float]) -> float:
512
+ """Compute Intersection over Union between two bounding boxes."""
513
+ x1 = max(bbox1[0], bbox2[0])
514
+ y1 = max(bbox1[1], bbox2[1])
515
+ x2 = min(bbox1[2], bbox2[2])
516
+ y2 = min(bbox1[3], bbox2[3])
517
+
518
+ if x2 <= x1 or y2 <= y1:
519
+ return 0.0
520
+
521
+ intersection = (x2 - x1) * (y2 - y1)
522
+ area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
523
+ area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
524
+ union = area1 + area2 - intersection
525
+
526
+ return intersection / union if union > 0 else 0.0
527
+
528
+ def compute_overlap_ratio(self, bbox1: List[float], bbox2: List[float]) -> Tuple[float, float]:
529
+ """Compute what percentage of each box overlaps with the other."""
530
+ x1 = max(bbox1[0], bbox2[0])
531
+ y1 = max(bbox1[1], bbox2[1])
532
+ x2 = min(bbox1[2], bbox2[2])
533
+ y2 = min(bbox1[3], bbox2[3])
534
+
535
+ if x2 <= x1 or y2 <= y1:
536
+ return 0.0, 0.0
537
+
538
+ intersection = (x2 - x1) * (y2 - y1)
539
+ area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
540
+ area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
541
+
542
+ overlap_ratio1 = intersection / area1 if area1 > 0 else 0.0
543
+ overlap_ratio2 = intersection / area2 if area2 > 0 else 0.0
544
+
545
+ return overlap_ratio1, overlap_ratio2
546
+
547
    def resolve_overlaps(self, normalized_elements: List[NormalizedElement]) -> List[NormalizedElement]:
        """Resolve overlaps by adjusting element positions - IMPROVED ALGORITHM.

        Three escalating strategies by overlap severity:
          * > 70%: drop the lower-confidence element,
          * 40-70%: push the pair apart along their dominant axis,
          * 20-40%: shrink each box's overlapping edge.
        Boxes are mutated in place; removed elements are filtered from the
        returned list.
        """
        print("\nπŸ” Checking for overlaps...")

        # Pairwise scan: record every pair whose max per-box overlap ratio
        # exceeds the threshold.
        overlaps = []
        for i in range(len(normalized_elements)):
            for j in range(i + 1, len(normalized_elements)):
                ne1 = normalized_elements[i]
                ne2 = normalized_elements[j]

                iou = self.compute_iou(ne1.normalized_bbox, ne2.normalized_bbox)
                if iou > 0:
                    overlap1, overlap2 = self.compute_overlap_ratio(
                        ne1.normalized_bbox, ne2.normalized_bbox
                    )
                    max_overlap = max(overlap1, overlap2)

                    if max_overlap >= self.overlap_threshold:
                        overlaps.append({
                            'idx1': i,
                            'idx2': j,
                            'elem1': ne1,
                            'elem2': ne2,
                            'overlap': max_overlap,
                            'overlap1': overlap1,
                            'overlap2': overlap2,
                            'iou': iou
                        })

        if not overlaps:
            print("βœ… No significant overlaps detected")
            return normalized_elements

        print(f"⚠️ Found {len(overlaps)} overlapping element pairs")

        # Sort by overlap severity so the worst conflicts are settled first.
        overlaps.sort(key=lambda x: x['overlap'], reverse=True)

        elements_to_remove = set()

        for overlap_info in overlaps:
            idx1 = overlap_info['idx1']
            idx2 = overlap_info['idx2']

            # Skip pairs involving an element already slated for removal.
            if idx1 in elements_to_remove or idx2 in elements_to_remove:
                continue

            elem1 = overlap_info['elem1']
            elem2 = overlap_info['elem2']
            overlap_ratio = overlap_info['overlap']

            # Strategy 1: Nearly complete overlap (>70%) - remove lower confidence
            if overlap_ratio > 0.7:
                if elem1.original.score < elem2.original.score:
                    elements_to_remove.add(idx1)
                    print(f" πŸ—‘οΈ Removing {elem1.original.label} (conf: {elem1.original.score:.2f}) - "
                          f"overlaps {overlap_ratio * 100:.1f}% with {elem2.original.label}")
                else:
                    elements_to_remove.add(idx2)
                    print(f" πŸ—‘οΈ Removing {elem2.original.label} (conf: {elem2.original.score:.2f}) - "
                          f"overlaps {overlap_ratio * 100:.1f}% with {elem1.original.label}")

            # Strategy 2: Significant overlap (40-70%) - try to separate
            elif overlap_ratio > 0.4:
                self._try_separate_elements(elem1, elem2, overlap_info)
                print(f" ↔️ Separating {elem1.original.label} and {elem2.original.label} "
                      f"(overlap: {overlap_ratio * 100:.1f}%)")

            # Strategy 3: Moderate overlap (20-40%) - shrink slightly
            else:
                self._shrink_overlapping_edges(elem1, elem2, overlap_info)
                print(f" πŸ“ Shrinking {elem1.original.label} and {elem2.original.label} "
                      f"(overlap: {overlap_ratio * 100:.1f}%)")

        if elements_to_remove:
            normalized_elements = [
                ne for i, ne in enumerate(normalized_elements)
                if i not in elements_to_remove
            ]
            print(f"βœ… Removed {len(elements_to_remove)} completely overlapping elements")

        return normalized_elements
629
+
630
    def _try_separate_elements(self, elem1: NormalizedElement, elem2: NormalizedElement,
                               overlap_info: Dict):
        """Try to separate two significantly overlapping elements - IMPROVED.

        Mutates both elements' ``normalized_bbox`` in place: the facing edges
        are moved to either side of the pair's midpoint (plus a small gap)
        along whichever axis their centers are farther apart on.
        """
        bbox1 = elem1.normalized_bbox
        bbox2 = elem2.normalized_bbox

        # Calculate overlap dimensions
        # NOTE(review): these overlap values are computed but not used below.
        overlap_x1 = max(bbox1[0], bbox2[0])
        overlap_y1 = max(bbox1[1], bbox2[1])
        overlap_x2 = min(bbox1[2], bbox2[2])
        overlap_y2 = min(bbox1[3], bbox2[3])

        overlap_width = overlap_x2 - overlap_x1
        overlap_height = overlap_y2 - overlap_y1

        # Calculate centers
        center1_x = (bbox1[0] + bbox1[2]) / 2
        center1_y = (bbox1[1] + bbox1[3]) / 2
        center2_x = (bbox2[0] + bbox2[2]) / 2
        center2_y = (bbox2[1] + bbox2[3]) / 2

        # Determine separation direction: split along the axis with the
        # larger center-to-center distance.
        dx = abs(center2_x - center1_x)
        dy = abs(center2_y - center1_y)

        # Add minimum gap
        min_gap = 3  # pixels

        if dx > dy:
            # Separate horizontally
            if center1_x < center2_x:
                # elem1 is left of elem2
                midpoint = (bbox1[2] + bbox2[0]) / 2
                bbox1[2] = midpoint - min_gap
                bbox2[0] = midpoint + min_gap
            else:
                # elem2 is left of elem1
                midpoint = (bbox2[2] + bbox1[0]) / 2
                bbox2[2] = midpoint - min_gap
                bbox1[0] = midpoint + min_gap
        else:
            # Separate vertically
            if center1_y < center2_y:
                # elem1 is above elem2
                midpoint = (bbox1[3] + bbox2[1]) / 2
                bbox1[3] = midpoint - min_gap
                bbox2[1] = midpoint + min_gap
            else:
                # elem2 is above elem1
                midpoint = (bbox2[3] + bbox1[1]) / 2
                bbox2[3] = midpoint - min_gap
                bbox1[1] = midpoint + min_gap

        # Ensure boxes remain valid
        self._ensure_valid_bbox(bbox1)
        self._ensure_valid_bbox(bbox2)
686
+
687
+ def _shrink_overlapping_edges(self, elem1: NormalizedElement, elem2: NormalizedElement,
688
+ overlap_info: Dict):
689
+ """Shrink overlapping edges for moderate overlaps."""
690
+ bbox1 = elem1.normalized_bbox
691
+ bbox2 = elem2.normalized_bbox
692
+
693
+ # Calculate overlap region
694
+ overlap_x1 = max(bbox1[0], bbox2[0])
695
+ overlap_y1 = max(bbox1[1], bbox2[1])
696
+ overlap_x2 = min(bbox1[2], bbox2[2])
697
+ overlap_y2 = min(bbox1[3], bbox2[3])
698
+
699
+ overlap_width = overlap_x2 - overlap_x1
700
+ overlap_height = overlap_y2 - overlap_y1
701
+
702
+ # Shrink by 50% of overlap plus small gap
703
+ gap = 2 # pixels
704
+
705
+ if overlap_width > overlap_height:
706
+ # Horizontal overlap is larger
707
+ shrink = overlap_width / 2 + gap
708
+ if bbox1[0] < bbox2[0]:
709
+ bbox1[2] -= shrink
710
+ bbox2[0] += shrink
711
+ else:
712
+ bbox2[2] -= shrink
713
+ bbox1[0] += shrink
714
+ else:
715
+ # Vertical overlap is larger
716
+ shrink = overlap_height / 2 + gap
717
+ if bbox1[1] < bbox2[1]:
718
+ bbox1[3] -= shrink
719
+ bbox2[1] += shrink
720
+ else:
721
+ bbox2[3] -= shrink
722
+ bbox1[1] += shrink
723
+
724
+ self._ensure_valid_bbox(bbox1)
725
+ self._ensure_valid_bbox(bbox2)
726
+
727
+ def _ensure_valid_bbox(self, bbox: List[float]):
728
+ """Ensure bounding box has minimum size and is within image bounds."""
729
+ min_size = 8 # Reduced minimum size
730
+
731
+ # Ensure minimum size
732
+ if bbox[2] - bbox[0] < min_size:
733
+ center_x = (bbox[0] + bbox[2]) / 2
734
+ bbox[0] = center_x - min_size / 2
735
+ bbox[2] = center_x + min_size / 2
736
+
737
+ if bbox[3] - bbox[1] < min_size:
738
+ center_y = (bbox[1] + bbox[3]) / 2
739
+ bbox[1] = center_y - min_size / 2
740
+ bbox[3] = center_y + min_size / 2
741
+
742
+ # Clamp to image bounds
743
+ bbox[0] = max(0, min(bbox[0], self.img_width))
744
+ bbox[1] = max(0, min(bbox[1], self.img_height))
745
+ bbox[2] = max(0, min(bbox[2], self.img_width))
746
+ bbox[3] = max(0, min(bbox[3], self.img_height))
747
+
748
+
749
# ============================================================================
# MAIN NORMALIZATION ENGINE
# ============================================================================
class LayoutNormalizer:
    """Main engine for normalizing wireframe layout.

    Orchestrates the full pipeline over raw detections: alignment detection,
    per-type size clustering, grid snapping, alignment correction, and
    overlap resolution. The step order matters - alignment corrections are
    applied after grid snapping, and overlap resolution runs last.
    """

    def __init__(self, elements: List[Element], img_width: float, img_height: float):
        # Raw detections plus the image dimensions that bound all boxes.
        self.elements = elements
        self.img_width = img_width
        self.img_height = img_height
        # Collaborators: grid snapping, alignment grouping, size clustering.
        self.grid = GridLayoutSystem(img_width, img_height)
        self.alignment_detector = AlignmentDetector(elements)
        self.size_normalizer = SizeNormalizer(elements, img_width, img_height)

    def normalize_layout(self) -> List[NormalizedElement]:
        """Normalize all elements with proper sizing and alignment.

        Returns:
            One NormalizedElement per surviving input element; overlap
            resolution may drop fully-contained duplicates, so the result
            can be shorter than `self.elements`.
        """
        print("\nπŸ”§ Starting layout normalization...")

        # Step 1: Detect alignments among the raw detections.
        h_alignments = self.alignment_detector.detect_horizontal_alignments()
        v_alignments = self.alignment_detector.detect_vertical_alignments()
        edge_alignments = self.alignment_detector.detect_edge_alignments()

        print(f"βœ“ Found {len(h_alignments)} horizontal alignment groups")
        print(f"βœ“ Found {len(v_alignments)} vertical alignment groups")

        # Step 2: Cluster sizes by type (buttons with buttons, etc.).
        size_clusters = self.size_normalizer.cluster_sizes_by_type()
        print(f"βœ“ Created size clusters for {len(size_clusters)} element types")

        # Step 3: Create element-to-cluster mapping, keyed by object
        # identity (id) since Element instances are not assumed hashable.
        # Also records a readable size-category name like "button_size_1".
        element_to_cluster = {}
        element_to_size_category = {}
        for label, clusters in size_clusters.items():
            for i, cluster in enumerate(clusters):
                category = f"{label}_size_{i + 1}"
                for elem in cluster:
                    element_to_cluster[id(elem)] = cluster
                    element_to_size_category[id(elem)] = category

        # Step 4: Normalize each element's size and snap it onto the grid.
        normalized_elements = []

        for elem in self.elements:
            # Fall back to a singleton cluster for unclustered elements.
            cluster = element_to_cluster.get(id(elem), [elem])
            size_category = element_to_size_category.get(id(elem), f"{elem.label}_default")

            # Get normalized size from the element's cluster.
            norm_width, norm_height = self.size_normalizer.get_normalized_size(elem, cluster)

            # Create normalized bbox (centered on original center).
            center_x, center_y = elem.center_x, elem.center_y
            norm_bbox = [
                center_x - norm_width / 2,
                center_y - norm_height / 2,
                center_x + norm_width / 2,
                center_y + norm_height / 2
            ]

            # Snap to grid - preserve original size better
            snapped_bbox = self.grid.snap_to_grid(norm_bbox, elem.label, preserve_size=True)
            grid_position = self.grid.get_grid_position(snapped_bbox)

            normalized_elements.append(NormalizedElement(
                original=elem,
                normalized_bbox=snapped_bbox,
                grid_position=grid_position,
                size_category=size_category
            ))

        # Step 5: Apply alignment corrections (moves boxes, keeps sizes).
        normalized_elements = self._apply_alignment_corrections(
            normalized_elements, h_alignments, v_alignments, edge_alignments
        )

        # Step 6: Resolve overlaps (may remove fully-contained elements).
        overlap_resolver = OverlapResolver(self.elements, self.img_width, self.img_height)
        normalized_elements = overlap_resolver.resolve_overlaps(normalized_elements)

        print(f"βœ… Normalized {len(normalized_elements)} elements")
        return normalized_elements

    def _apply_alignment_corrections(self, normalized_elements: List[NormalizedElement],
                                     h_alignments: List[List[Element]],
                                     v_alignments: List[List[Element]],
                                     edge_alignments: Dict) -> List[NormalizedElement]:
        """Apply alignment corrections to normalized elements.

        Alignment groups are expressed in terms of the ORIGINAL elements, so
        they are mapped back to their normalized wrappers by identity first.
        Each correction preserves an element's width/height and only moves it
        so the group shares an averaged center line or edge.
        """

        # Lookup: id(original element) -> its normalized wrapper.
        elem_to_normalized = {id(ne.original): ne for ne in normalized_elements}

        # Horizontally-aligned groups share an averaged vertical center.
        for h_group in h_alignments:
            norm_group = [elem_to_normalized[id(e)] for e in h_group if id(e) in elem_to_normalized]
            if len(norm_group) > 1:
                # Align to average Y position
                avg_y = sum((ne.normalized_bbox[1] + ne.normalized_bbox[3]) / 2 for ne in norm_group) / len(norm_group)
                for ne in norm_group:
                    height = ne.normalized_bbox[3] - ne.normalized_bbox[1]
                    ne.normalized_bbox[1] = avg_y - height / 2
                    ne.normalized_bbox[3] = avg_y + height / 2

        # Vertically-aligned groups share an averaged horizontal center.
        for v_group in v_alignments:
            norm_group = [elem_to_normalized[id(e)] for e in v_group if id(e) in elem_to_normalized]
            if len(norm_group) > 1:
                # Align to average X position
                avg_x = sum((ne.normalized_bbox[0] + ne.normalized_bbox[2]) / 2 for ne in norm_group) / len(norm_group)
                for ne in norm_group:
                    width = ne.normalized_bbox[2] - ne.normalized_bbox[0]
                    ne.normalized_bbox[0] = avg_x - width / 2
                    ne.normalized_bbox[2] = avg_x + width / 2

        # Edge-aligned groups share an averaged left/right/top/bottom edge.
        for edge_type, groups in edge_alignments.items():
            for edge_group in groups:
                norm_group = [elem_to_normalized[id(e)] for e in edge_group if id(e) in elem_to_normalized]
                if len(norm_group) > 1:
                    if edge_type == 'left':
                        avg_left = sum(ne.normalized_bbox[0] for ne in norm_group) / len(norm_group)
                        for ne in norm_group:
                            width = ne.normalized_bbox[2] - ne.normalized_bbox[0]
                            ne.normalized_bbox[0] = avg_left
                            ne.normalized_bbox[2] = avg_left + width
                    elif edge_type == 'right':
                        avg_right = sum(ne.normalized_bbox[2] for ne in norm_group) / len(norm_group)
                        for ne in norm_group:
                            width = ne.normalized_bbox[2] - ne.normalized_bbox[0]
                            ne.normalized_bbox[2] = avg_right
                            ne.normalized_bbox[0] = avg_right - width
                    elif edge_type == 'top':
                        avg_top = sum(ne.normalized_bbox[1] for ne in norm_group) / len(norm_group)
                        for ne in norm_group:
                            height = ne.normalized_bbox[3] - ne.normalized_bbox[1]
                            ne.normalized_bbox[1] = avg_top
                            ne.normalized_bbox[3] = avg_top + height
                    elif edge_type == 'bottom':
                        avg_bottom = sum(ne.normalized_bbox[3] for ne in norm_group) / len(norm_group)
                        for ne in norm_group:
                            height = ne.normalized_bbox[3] - ne.normalized_bbox[1]
                            ne.normalized_bbox[3] = avg_bottom
                            ne.normalized_bbox[1] = avg_bottom - height

        return normalized_elements
894
+
895
+
896
# ============================================================================
# VISUALIZATION & EXPORT
# ============================================================================
def visualize_comparison(pil_img: Image.Image, elements: List[Element],
                         normalized_elements: List[NormalizedElement],
                         grid_system: GridLayoutSystem):
    """Visualize original vs normalized layout.

    Renders a side-by-side matplotlib figure: raw detections (red boxes) on
    the left, and the normalized layout on the right with the grid overlaid
    and each original box drawn dashed underneath its normalized box.
    Blocks on plt.show() under interactive backends.
    """

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 12))

    # Left panel: original detections.
    ax1.imshow(pil_img)
    ax1.set_title("Original Predictions", fontsize=16, weight='bold')
    ax1.axis('off')

    for elem in elements:
        x1, y1, x2, y2 = elem.bbox
        rect = patches.Rectangle(
            (x1, y1), x2 - x1, y2 - y1,
            linewidth=2, edgecolor='red', facecolor='none'
        )
        ax1.add_patch(rect)
        ax1.text(x1, y1 - 5, elem.label, color='red', fontsize=8,
                 bbox=dict(facecolor='white', alpha=0.7))

    # Right panel: normalized layout over the same image.
    ax2.imshow(pil_img)
    ax2.set_title("Normalized & Aligned Layout", fontsize=16, weight='bold')
    ax2.axis('off')

    # Faint dotted grid lines so the snapping is visible.
    for x in range(grid_system.num_columns + 1):
        x_pos = x * grid_system.cell_width
        ax2.axvline(x=x_pos, color='blue', linestyle=':', linewidth=0.5, alpha=0.3)
    for y in range(grid_system.num_rows + 1):
        y_pos = y * grid_system.cell_height
        ax2.axhline(y=y_pos, color='blue', linestyle=':', linewidth=0.5, alpha=0.3)

    # One fixed color per class label.
    # NOTE(review): the seed has no effect here - colors come from linspace,
    # not from random draws; left in place to avoid changing behavior.
    np.random.seed(42)
    colors = plt.cm.Set3(np.linspace(0, 1, len(CLASS_NAMES)))
    color_map = {name: colors[i] for i, name in enumerate(CLASS_NAMES)}

    for norm_elem in normalized_elements:
        x1, y1, x2, y2 = norm_elem.normalized_bbox
        color = color_map[norm_elem.original.label]

        # Normalized box (thick)
        rect = patches.Rectangle(
            (x1, y1), x2 - x1, y2 - y1,
            linewidth=3, edgecolor=color, facecolor='none'
        )
        ax2.add_patch(rect)

        # Original box (thin, dashed) for before/after comparison.
        ox1, oy1, ox2, oy2 = norm_elem.original.bbox
        orig_rect = patches.Rectangle(
            (ox1, oy1), ox2 - ox1, oy2 - oy1,
            linewidth=1, edgecolor='gray', facecolor='none',
            linestyle='--', alpha=0.5
        )
        ax2.add_patch(orig_rect)

        # Label: element type, size category, and grid row/column.
        grid_pos = norm_elem.grid_position
        label_text = f"{norm_elem.original.label}\n{norm_elem.size_category}\nR{grid_pos['start_row']} C{grid_pos['start_col']}"
        ax2.text(x1 + 5, y1 + 15, label_text, color='white', fontsize=7,
                 bbox=dict(facecolor=color, alpha=0.8, pad=2))

    plt.tight_layout()
    plt.show()
967
+
968
+
969
def export_to_json(normalized_elements: List[NormalizedElement],
                   grid_system: GridLayoutSystem,
                   output_path: str):
    """Export normalized layout to JSON.

    Writes a document with a `metadata` section (image size, grid geometry,
    element count) and an `elements` array carrying, for each element, its
    original and normalized boxes, grid position, and percentage coordinates
    relative to the image.
    """

    def _element_record(idx, norm_elem):
        # Build one serializable record for a single normalized element.
        src = norm_elem.original
        nb = norm_elem.normalized_bbox
        return {
            'id': idx,
            'type': src.label,
            'confidence': round(src.score, 3),
            'size_category': norm_elem.size_category,
            'original_bbox': {
                'x1': round(src.bbox[0], 2),
                'y1': round(src.bbox[1], 2),
                'x2': round(src.bbox[2], 2),
                'y2': round(src.bbox[3], 2),
                'width': round(src.width, 2),
                'height': round(src.height, 2)
            },
            'normalized_bbox': {
                'x1': round(nb[0], 2),
                'y1': round(nb[1], 2),
                'x2': round(nb[2], 2),
                'y2': round(nb[3], 2),
                'width': round(nb[2] - nb[0], 2),
                'height': round(nb[3] - nb[1], 2)
            },
            'grid_position': norm_elem.grid_position,
            'percentage': {
                'x1': round((nb[0] / grid_system.img_width) * 100, 2),
                'y1': round((nb[1] / grid_system.img_height) * 100, 2),
                'x2': round((nb[2] / grid_system.img_width) * 100, 2),
                'y2': round((nb[3] / grid_system.img_height) * 100, 2)
            }
        }

    payload = {
        'metadata': {
            'image_width': grid_system.img_width,
            'image_height': grid_system.img_height,
            'grid_system': {
                'columns': grid_system.num_columns,
                'rows': grid_system.num_rows,
                'cell_width': round(grid_system.cell_width, 2),
                'cell_height': round(grid_system.cell_height, 2)
            },
            'total_elements': len(normalized_elements)
        },
        'elements': [_element_record(idx, ne) for idx, ne in enumerate(normalized_elements)]
    }

    # A bare filename has no directory part; default to the current dir.
    target_dir = os.path.dirname(output_path) or '.'
    os.makedirs(target_dir, exist_ok=True)
    with open(output_path, 'w') as f:
        json.dump(payload, f, indent=2)

    print(f"\nβœ… Exported normalized layout to: {output_path}")
1030
+
1031
+
1032
def export_to_html(normalized_elements: List[NormalizedElement],
                   grid_system: GridLayoutSystem,
                   output_path: str):
    """Export normalized layout as responsive HTML/CSS.

    Renders each element as an absolutely-positioned <div> inside a
    fixed-size container matching the source image dimensions. The div's CSS
    class list includes the detection label (button, checkbox, ...), so
    labels outside the styled set fall back to the generic .element style.
    A fixed info panel shows grid and element counts.
    """

    # Literal CSS braces are doubled ({{ }}) because this template goes
    # through str.format() below; single-brace fields are the placeholders.
    html_template = """<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Wireframe Layout</title>
    <style>
        * {{
            margin: 0;
            padding: 0;
            box-sizing: border-box;
        }}

        body {{
            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Arial, sans-serif;
            background: #f5f5f5;
            padding: 20px;
        }}

        .container {{
            max-width: {img_width}px;
            margin: 0 auto;
            background: white;
            position: relative;
            height: {img_height}px;
            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
        }}

        .element {{
            position: absolute;
            border: 2px solid #333;
            display: flex;
            align-items: center;
            justify-content: center;
            font-size: 12px;
            color: #666;
            background: rgba(255,255,255,0.9);
            transition: all 0.3s ease;
        }}

        .element:hover {{
            z-index: 100;
            box-shadow: 0 4px 12px rgba(0,0,0,0.2);
            transform: scale(1.02);
        }}

        .element-label {{
            font-weight: bold;
            font-size: 10px;
            text-transform: uppercase;
        }}

        /* Element type specific styles */
        .button {{
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            border-radius: 6px;
            font-weight: bold;
            cursor: pointer;
        }}

        .checkbox {{
            background: white;
            border: 2px solid #4a5568;
            border-radius: 4px;
        }}

        .textfield {{
            background: white;
            border: 2px solid #cbd5e0;
            border-radius: 4px;
            padding: 8px;
        }}

        .text {{
            background: transparent;
            border: 1px dashed #cbd5e0;
            color: #2d3748;
        }}

        .paragraph {{
            background: transparent;
            border: 1px dashed #cbd5e0;
            color: #4a5568;
            text-align: left;
            padding: 8px;
        }}

        .image {{
            background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
            color: white;
            border: none;
        }}

        .navbar {{
            background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%);
            color: white;
            font-weight: bold;
            border: none;
        }}

        .info-panel {{
            position: fixed;
            top: 20px;
            right: 20px;
            background: white;
            padding: 20px;
            border-radius: 8px;
            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
            max-width: 300px;
        }}

        .info-panel h3 {{
            margin-bottom: 10px;
            color: #2d3748;
        }}

        .info-panel p {{
            margin: 5px 0;
            font-size: 14px;
            color: #4a5568;
        }}
    </style>
</head>
<body>
    <div class="info-panel">
        <h3>πŸ“ Layout Info</h3>
        <p><strong>Grid:</strong> {grid_cols} Γ— {grid_rows}</p>
        <p><strong>Elements:</strong> {total_elements}</p>
        <p><strong>Dimensions:</strong> {img_width}px Γ— {img_height}px</p>
        <p style="margin-top: 15px; font-size: 12px; color: #718096;">
            Hover over elements to see details
        </p>
    </div>

    <div class="container">
        {elements_html}
    </div>
</body>
</html>"""

    elements_html = []

    # One absolutely-positioned div per element; the hover tooltip (title)
    # carries the grid position and size category.
    for i, norm_elem in enumerate(normalized_elements):
        x1, y1, x2, y2 = norm_elem.normalized_bbox
        width = x2 - x1
        height = y2 - y1

        element_html = f"""
    <div class="element {norm_elem.original.label}"
         style="left: {x1}px; top: {y1}px; width: {width}px; height: {height}px;"
         title="{norm_elem.original.label} | Grid: R{norm_elem.grid_position['start_row']} C{norm_elem.grid_position['start_col']} | Size: {norm_elem.size_category}">
        <span class="element-label">{norm_elem.original.label}</span>
    </div>"""

        elements_html.append(element_html)

    # Fill the template; image dimensions are truncated to whole pixels.
    html_content = html_template.format(
        img_width=int(grid_system.img_width),
        img_height=int(grid_system.img_height),
        grid_cols=grid_system.num_columns,
        grid_rows=grid_system.num_rows,
        total_elements=len(normalized_elements),
        elements_html='\n'.join(elements_html)
    )

    # A bare filename has no directory part; default to the current dir.
    os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else '.', exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(html_content)

    print(f"βœ… Exported HTML layout to: {output_path}")
1208
+
1209
+
1210
# ============================================================================
# MAIN PIPELINE
# ============================================================================
def process_wireframe(image_path: str,
                      save_json: bool = True,
                      save_html: bool = True,
                      show_visualization: bool = True) -> Dict:
    """
    Complete pipeline to process wireframe image.

    Lazily loads the detection model into a module-level global on first
    call, runs inference, normalizes the detected layout, and optionally
    exports JSON/HTML and a matplotlib comparison.

    Args:
        image_path: Path to wireframe image
        save_json: Export normalized layout as JSON
        save_html: Export normalized layout as HTML
        show_visualization: Display matplotlib comparison

    Returns:
        Dictionary containing all processing results; an empty dict when
        the model cannot be loaded or no elements are detected.
    """

    print("=" * 80)
    print("πŸš€ WIREFRAME LAYOUT NORMALIZER")
    print("=" * 80)

    # Step 1: Load model and get predictions.
    # The model is cached in a module-level global so repeated calls
    # (e.g. batch processing) only pay the load cost once.
    global model
    if model is None:
        print("\nπŸ“¦ Loading model...")
        try:
            model = tf.keras.models.load_model(
                MODEL_PATH,
                custom_objects={'LossCalculation': LossCalculation}
            )
            print("βœ… Model loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            print("\nTrying alternative loading method...")
            try:
                # Fallback: load without compiling so a missing/renamed
                # custom loss does not block inference-only use.
                model = tf.keras.models.load_model(MODEL_PATH, compile=False)
                print("βœ… Model loaded successfully (without compilation)!")
            except Exception as e2:
                print(f"❌ Failed to load model: {e2}")
                return {}

    print(f"\nπŸ“Έ Processing image: {image_path}")
    pil_img, elements = get_predictions(image_path)
    print(f"βœ… Detected {len(elements)} elements")

    if not elements:
        print("⚠️ No elements detected. Exiting.")
        return {}

    # Step 2: Normalize layout (sizing, alignment, grid, overlaps).
    normalizer = LayoutNormalizer(elements, pil_img.width, pil_img.height)
    normalized_elements = normalizer.normalize_layout()

    # Step 3: Generate outputs named after the input image.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    base_filename = os.path.splitext(os.path.basename(image_path))[0]

    results = {
        'image': pil_img,
        'original_elements': elements,
        'normalized_elements': normalized_elements,
        'grid_system': normalizer.grid
    }

    # Export JSON
    if save_json:
        json_path = os.path.join(OUTPUT_DIR, f"{base_filename}_normalized.json")
        export_to_json(normalized_elements, normalizer.grid, json_path)
        results['json_path'] = json_path

    # Export HTML
    if save_html:
        html_path = os.path.join(OUTPUT_DIR, f"{base_filename}_layout.html")
        export_to_html(normalized_elements, normalizer.grid, html_path)
        results['html_path'] = html_path

    # Visualize (blocking under interactive matplotlib backends).
    if show_visualization:
        print("\n🎨 Generating visualization...")
        visualize_comparison(pil_img, elements, normalized_elements, normalizer.grid)

    # Print summary
    print("\n" + "=" * 80)
    print("πŸ“Š PROCESSING SUMMARY")
    print("=" * 80)

    # Count detections by element type.
    type_counts = {}
    for elem in elements:
        type_counts[elem.label] = type_counts.get(elem.label, 0) + 1

    print(f"\nπŸ“¦ Element Types:")
    for elem_type, count in sorted(type_counts.items()):
        print(f"  β€’ {elem_type}: {count}")

    # Count distinct size categories across normalized elements.
    size_categories = {}
    for norm_elem in normalized_elements:
        size_categories[norm_elem.size_category] = size_categories.get(norm_elem.size_category, 0) + 1

    print(f"\nπŸ“ Size Categories: {len(size_categories)}")

    # Alignment info (recomputed here purely for the report).
    h_alignments = normalizer.alignment_detector.detect_horizontal_alignments()
    v_alignments = normalizer.alignment_detector.detect_vertical_alignments()

    print(f"\nπŸ“ Alignment:")
    print(f"  β€’ Horizontal groups: {len(h_alignments)}")
    print(f"  β€’ Vertical groups: {len(v_alignments)}")

    print("\n" + "=" * 80)
    print("βœ… PROCESSING COMPLETE!")
    print("=" * 80 + "\n")

    return results
1328
+
1329
+
1330
def batch_process(image_dir: str, pattern: str = "*.png") -> List[Dict]:
    """Process multiple wireframe images in a directory.

    Args:
        image_dir: Directory containing the wireframe images.
        pattern: Glob pattern used to select images (e.g. "*.png").

    Returns:
        One entry per image: {'image_path', 'success', and either 'results'
        or 'error'}. Returns an empty list when no images match the pattern
        (previously returned None, which broke `for r in batch_process(...)`
        style callers).
    """
    import glob

    image_paths = glob.glob(os.path.join(image_dir, pattern))

    if not image_paths:
        print(f"❌ No images found matching pattern: {pattern}")
        # Empty list (not None) so callers can iterate unconditionally.
        return []

    print(f"πŸ“‚ Found {len(image_paths)} images to process\n")

    all_results = []
    for i, image_path in enumerate(image_paths, 1):
        print(f"\n{'=' * 80}")
        print(f"Processing image {i}/{len(image_paths)}: {os.path.basename(image_path)}")
        print(f"{'=' * 80}")

        try:
            results = process_wireframe(
                image_path,
                save_json=True,
                save_html=True,
                show_visualization=False  # headless: no blocking plt.show()
            )
            all_results.append({
                'image_path': image_path,
                'success': True,
                'results': results
            })
        except Exception as e:
            # Best-effort batch: record the failure and keep going.
            print(f"❌ Error processing {image_path}: {str(e)}")
            all_results.append({
                'image_path': image_path,
                'success': False,
                'error': str(e)
            })

    # Summary
    successful = sum(1 for r in all_results if r['success'])
    print(f"\n{'=' * 80}")
    print(f"πŸ“Š BATCH PROCESSING COMPLETE")
    print(f"{'=' * 80}")
    print(f"βœ… Successful: {successful}/{len(image_paths)}")
    print(f"❌ Failed: {len(image_paths) - successful}/{len(image_paths)}")

    return all_results
1377
+
1378
+
1379
# ============================================================================
# EXAMPLE USAGE
# ============================================================================
if __name__ == "__main__":
    # Single image processing: point at one wireframe image on disk.
    image_path = "./image/6LHls1vE.jpg"

    # Process with all outputs (JSON + HTML + blocking matplotlib window).
    results = process_wireframe(
        image_path,
        save_json=True,
        save_html=True,
        show_visualization=True
    )

    # Or batch process multiple images
    # batch_results = batch_process("./wireframes/", pattern="*.png")