Spaces:
Sleeping
Sleeping
Commit ·
e9e7a8d
1
Parent(s): 9c67a05
Updated detection module script and requirements.txt to use ONNX models rather than Keras
Browse files- requirements.txt +6 -6
- visualization.py +159 -216
- wireframe_detection_model_best_700.keras → wireframe.onnx +2 -2
requirements.txt
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
-
|
| 2 |
-
gradio
|
| 3 |
-
fastapi
|
| 4 |
-
uvicorn
|
| 5 |
numpy
|
| 6 |
pillow
|
| 7 |
-
matplotlib
|
| 8 |
opencv-python-headless
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
onnxruntime==1.19.0
|
|
|
|
|
|
|
|
|
|
| 2 |
numpy
|
| 3 |
pillow
|
|
|
|
| 4 |
opencv-python-headless
|
| 5 |
+
fastapi
|
| 6 |
+
uvicorn
|
| 7 |
+
gradio
|
| 8 |
+
python-multipart
|
| 9 |
+
matplotlib
|
visualization.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
import
|
| 2 |
import numpy as np
|
| 3 |
import matplotlib.pyplot as plt
|
| 4 |
import matplotlib.patches as patches
|
|
@@ -11,76 +11,9 @@ from typing import List, Tuple, Dict, Optional
|
|
| 11 |
|
| 12 |
|
| 13 |
# ============================================================================
|
| 14 |
-
#
|
| 15 |
# ============================================================================
|
| 16 |
-
|
| 17 |
-
class LossCalculation(tf.keras.losses.Loss):
|
| 18 |
-
"""Custom loss function for wireframe detection."""
|
| 19 |
-
|
| 20 |
-
def __init__(self, num_classes=7, lambda_coord=5.0, lambda_noobj=0.5,
|
| 21 |
-
name='loss_calculation', reduction='sum_over_batch_size', **kwargs):
|
| 22 |
-
super().__init__(name=name, reduction=reduction)
|
| 23 |
-
self.num_classes = num_classes
|
| 24 |
-
self.lambda_coord = lambda_coord
|
| 25 |
-
self.lambda_noobj = lambda_noobj
|
| 26 |
-
|
| 27 |
-
def call(self, y_true, y_pred):
|
| 28 |
-
obj_true = y_true[..., 0]
|
| 29 |
-
box_true = y_true[..., 1:5]
|
| 30 |
-
cls_true = y_true[..., 5:]
|
| 31 |
-
|
| 32 |
-
obj_pred_logits = y_pred[..., 0]
|
| 33 |
-
box_pred = y_pred[..., 1:5]
|
| 34 |
-
cls_pred_logits = y_pred[..., 5:]
|
| 35 |
-
|
| 36 |
-
obj_mask = tf.cast(obj_true > 0.5, tf.float32)
|
| 37 |
-
noobj_mask = 1.0 - obj_mask
|
| 38 |
-
num_pos = tf.maximum(tf.reduce_sum(obj_mask), 1.0)
|
| 39 |
-
|
| 40 |
-
obj_loss_pos = obj_mask * tf.nn.sigmoid_cross_entropy_with_logits(
|
| 41 |
-
labels=obj_true, logits=obj_pred_logits)
|
| 42 |
-
obj_loss_neg = noobj_mask * tf.nn.sigmoid_cross_entropy_with_logits(
|
| 43 |
-
labels=obj_true, logits=obj_pred_logits)
|
| 44 |
-
obj_loss = (tf.reduce_sum(obj_loss_pos) + self.lambda_noobj * tf.reduce_sum(obj_loss_neg)) / tf.cast(
|
| 45 |
-
tf.size(obj_true), tf.float32)
|
| 46 |
-
|
| 47 |
-
xy_pred = tf.nn.sigmoid(box_pred[..., 0:2])
|
| 48 |
-
wh_pred = tf.nn.sigmoid(box_pred[..., 2:4])
|
| 49 |
-
xy_true = box_true[..., 0:2]
|
| 50 |
-
wh_true = box_true[..., 2:4]
|
| 51 |
-
|
| 52 |
-
xy_loss = tf.reduce_sum(obj_mask[..., tf.newaxis] * self._smooth_l1_loss(xy_true - xy_pred)) / num_pos
|
| 53 |
-
wh_loss = tf.reduce_sum(obj_mask[..., tf.newaxis] * self._smooth_l1_loss(wh_true - wh_pred)) / num_pos
|
| 54 |
-
box_loss = self.lambda_coord * (xy_loss + wh_loss)
|
| 55 |
-
|
| 56 |
-
cls_loss = tf.reduce_sum(obj_mask * tf.nn.softmax_cross_entropy_with_logits(
|
| 57 |
-
labels=cls_true, logits=cls_pred_logits)) / num_pos
|
| 58 |
-
|
| 59 |
-
total_loss = obj_loss + box_loss + cls_loss
|
| 60 |
-
return tf.clip_by_value(total_loss, 0.0, 100.0)
|
| 61 |
-
|
| 62 |
-
def _smooth_l1_loss(self, x, beta=1.0):
|
| 63 |
-
abs_x = tf.abs(x)
|
| 64 |
-
return tf.where(abs_x < beta, 0.5 * x * x / beta, abs_x - 0.5 * beta)
|
| 65 |
-
|
| 66 |
-
def get_config(self):
|
| 67 |
-
config = super().get_config()
|
| 68 |
-
config.update({
|
| 69 |
-
'num_classes': self.num_classes,
|
| 70 |
-
'lambda_coord': self.lambda_coord,
|
| 71 |
-
'lambda_noobj': self.lambda_noobj,
|
| 72 |
-
})
|
| 73 |
-
return config
|
| 74 |
-
|
| 75 |
-
@classmethod
|
| 76 |
-
def from_config(cls, config):
|
| 77 |
-
return cls(**config)
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
# ============================================================================
|
| 81 |
-
# CONFIGURATION - UPDATED FOR BETTER PRECISION
|
| 82 |
-
# ============================================================================
|
| 83 |
-
MODEL_PATH = "./wireframe_detection_model_best_700.keras"
|
| 84 |
OUTPUT_DIR = "./output/"
|
| 85 |
CLASS_NAMES = ["button", "checkbox", "image", "navbar", "paragraph", "text", "textfield"]
|
| 86 |
|
|
@@ -88,27 +21,102 @@ IMG_SIZE = 416
|
|
| 88 |
CONF_THRESHOLD = 0.1
|
| 89 |
IOU_THRESHOLD = 0.1
|
| 90 |
|
| 91 |
-
# Layout Configuration
|
| 92 |
-
GRID_COLUMNS = 24
|
| 93 |
-
ALIGNMENT_THRESHOLD = 10
|
| 94 |
-
SIZE_CLUSTERING_THRESHOLD = 15
|
| 95 |
|
| 96 |
-
# Standard sizes for each element type (relative units)
|
| 97 |
STANDARD_SIZES = {
|
| 98 |
-
'button': {'width': 2, 'height': 1},
|
| 99 |
-
'checkbox': {'width': 1, 'height': 1},
|
| 100 |
-
'textfield': {'width': 5, 'height': 1},
|
| 101 |
-
'text': {'width': 3, 'height': 1},
|
| 102 |
-
'paragraph': {'width': 8, 'height': 2},
|
| 103 |
-
'image': {'width': 4, 'height': 4},
|
| 104 |
-
'navbar': {'width': 24, 'height': 1}
|
| 105 |
}
|
| 106 |
|
| 107 |
-
|
| 108 |
|
| 109 |
|
| 110 |
# ============================================================================
|
| 111 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
# ============================================================================
|
| 113 |
@dataclass
|
| 114 |
class Element:
|
|
@@ -139,13 +147,13 @@ class NormalizedElement:
|
|
| 139 |
|
| 140 |
|
| 141 |
# ============================================================================
|
| 142 |
-
# PREDICTION EXTRACTION
|
| 143 |
# ============================================================================
|
| 144 |
def get_predictions(image_path: str) -> Tuple[Image.Image, List[Element]]:
|
| 145 |
-
"""Extract predictions from the model."""
|
| 146 |
-
global
|
| 147 |
-
if
|
| 148 |
-
raise ValueError("
|
| 149 |
|
| 150 |
# Load and preprocess image
|
| 151 |
pil_img = Image.open(image_path).convert("RGB")
|
|
@@ -155,25 +163,28 @@ def get_predictions(image_path: str) -> Tuple[Image.Image, List[Element]]:
|
|
| 155 |
img_array = np.array(resized_img, dtype=np.float32) / 255.0
|
| 156 |
input_tensor = np.expand_dims(img_array, axis=0)
|
| 157 |
|
| 158 |
-
# Get predictions
|
| 159 |
-
|
|
|
|
|
|
|
|
|
|
| 160 |
raw_boxes = []
|
| 161 |
S = pred_grid.shape[0]
|
| 162 |
cell_size = 1.0 / S
|
| 163 |
|
| 164 |
for row in range(S):
|
| 165 |
for col in range(S):
|
| 166 |
-
obj_score = float(
|
| 167 |
if obj_score < CONF_THRESHOLD:
|
| 168 |
continue
|
| 169 |
|
| 170 |
-
x_offset = float(
|
| 171 |
-
y_offset = float(
|
| 172 |
-
width = float(
|
| 173 |
-
height = float(
|
| 174 |
|
| 175 |
class_logits = pred_grid[row, col, 5:]
|
| 176 |
-
class_probs =
|
| 177 |
class_id = int(np.argmax(class_probs))
|
| 178 |
class_conf = float(class_probs[class_id])
|
| 179 |
final_score = obj_score * class_conf
|
|
@@ -191,25 +202,24 @@ def get_predictions(image_path: str) -> Tuple[Image.Image, List[Element]]:
|
|
| 191 |
if x2 > x1 and y2 > y1:
|
| 192 |
raw_boxes.append((class_id, final_score, x1, y1, x2, y2))
|
| 193 |
|
| 194 |
-
# Apply NMS per class
|
| 195 |
elements = []
|
| 196 |
for class_id in range(len(CLASS_NAMES)):
|
| 197 |
class_boxes = [(score, x1, y1, x2, y2) for cid, score, x1, y1, x2, y2 in raw_boxes if cid == class_id]
|
| 198 |
if not class_boxes:
|
| 199 |
continue
|
| 200 |
|
| 201 |
-
scores = [b[0] for b in class_boxes]
|
| 202 |
-
boxes_xyxy = [[b[1], b[2], b[3], b[4]] for b in class_boxes]
|
| 203 |
|
| 204 |
-
selected_indices =
|
| 205 |
boxes=boxes_xyxy,
|
| 206 |
scores=scores,
|
| 207 |
-
max_output_size=50,
|
| 208 |
iou_threshold=IOU_THRESHOLD,
|
| 209 |
score_threshold=CONF_THRESHOLD
|
| 210 |
)
|
| 211 |
|
| 212 |
-
for idx in selected_indices
|
| 213 |
score, x1, y1, x2, y2 = class_boxes[idx]
|
| 214 |
elements.append(Element(
|
| 215 |
label=CLASS_NAMES[class_id],
|
|
@@ -221,7 +231,7 @@ def get_predictions(image_path: str) -> Tuple[Image.Image, List[Element]]:
|
|
| 221 |
|
| 222 |
|
| 223 |
# ============================================================================
|
| 224 |
-
# ALIGNMENT DETECTION
|
| 225 |
# ============================================================================
|
| 226 |
class AlignmentDetector:
|
| 227 |
"""Detects alignment relationships between elements."""
|
|
@@ -333,7 +343,7 @@ class AlignmentDetector:
|
|
| 333 |
|
| 334 |
|
| 335 |
# ============================================================================
|
| 336 |
-
# SIZE NORMALIZATION
|
| 337 |
# ============================================================================
|
| 338 |
class SizeNormalizer:
|
| 339 |
"""Normalizes element sizes based on type and clustering."""
|
|
@@ -384,17 +394,13 @@ class SizeNormalizer:
|
|
| 384 |
return clusters
|
| 385 |
|
| 386 |
def get_normalized_size(self, element: Element, size_cluster: List[Element]) -> Tuple[float, float]:
|
| 387 |
-
"""Get normalized size for an element based on its cluster
|
| 388 |
-
# Use the actual detected size instead of aggressive averaging
|
| 389 |
-
# Only normalize if there's a significant cluster
|
| 390 |
if len(size_cluster) >= 3:
|
| 391 |
-
# Use median instead of mean to avoid outliers
|
| 392 |
widths = sorted([e.width for e in size_cluster])
|
| 393 |
heights = sorted([e.height for e in size_cluster])
|
| 394 |
median_width = widths[len(widths) // 2]
|
| 395 |
median_height = heights[len(heights) // 2]
|
| 396 |
|
| 397 |
-
# Only normalize if element is within 30% of median
|
| 398 |
if abs(element.width - median_width) / median_width < 0.3:
|
| 399 |
normalized_width = round(median_width)
|
| 400 |
else:
|
|
@@ -405,7 +411,6 @@ class SizeNormalizer:
|
|
| 405 |
else:
|
| 406 |
normalized_height = round(element.height)
|
| 407 |
else:
|
| 408 |
-
# Small cluster - keep original size
|
| 409 |
normalized_width = round(element.width)
|
| 410 |
normalized_height = round(element.height)
|
| 411 |
|
|
@@ -413,7 +418,7 @@ class SizeNormalizer:
|
|
| 413 |
|
| 414 |
|
| 415 |
# ============================================================================
|
| 416 |
-
# GRID-BASED LAYOUT SYSTEM
|
| 417 |
# ============================================================================
|
| 418 |
class GridLayoutSystem:
|
| 419 |
"""Grid-based layout system for precise positioning."""
|
|
@@ -432,44 +437,36 @@ class GridLayoutSystem:
|
|
| 432 |
print(f"π Cell size: {self.cell_width:.1f}px × {self.cell_height:.1f}px")
|
| 433 |
|
| 434 |
def snap_to_grid(self, bbox: List[float], element_label: str, preserve_size: bool = True) -> List[float]:
|
| 435 |
-
"""Snap bounding box to grid
|
| 436 |
x1, y1, x2, y2 = bbox
|
| 437 |
original_width = x2 - x1
|
| 438 |
original_height = y2 - y1
|
| 439 |
|
| 440 |
-
# Calculate center
|
| 441 |
center_x = (x1 + x2) / 2
|
| 442 |
center_y = (y1 + y2) / 2
|
| 443 |
|
| 444 |
-
# Find nearest grid cell for center
|
| 445 |
center_col = round(center_x / self.cell_width)
|
| 446 |
center_row = round(center_y / self.cell_height)
|
| 447 |
|
| 448 |
if preserve_size:
|
| 449 |
-
# Calculate span based on actual size (don't force to standard)
|
| 450 |
width_cells = max(1, round(original_width / self.cell_width))
|
| 451 |
height_cells = max(1, round(original_height / self.cell_height))
|
| 452 |
else:
|
| 453 |
-
# Use standard size
|
| 454 |
standard = STANDARD_SIZES.get(element_label, {'width': 2, 'height': 1})
|
| 455 |
width_cells = max(1, round(original_width / self.cell_width))
|
| 456 |
height_cells = max(1, round(original_height / self.cell_height))
|
| 457 |
|
| 458 |
-
# Only adjust to standard if very close
|
| 459 |
if abs(width_cells - standard['width']) <= 0.5:
|
| 460 |
width_cells = standard['width']
|
| 461 |
if abs(height_cells - standard['height']) <= 0.5:
|
| 462 |
height_cells = standard['height']
|
| 463 |
|
| 464 |
-
# Calculate start position (center the element)
|
| 465 |
start_col = center_col - width_cells // 2
|
| 466 |
start_row = center_row - height_cells // 2
|
| 467 |
|
| 468 |
-
# Clamp to grid bounds
|
| 469 |
start_col = max(0, min(start_col, self.num_columns - width_cells))
|
| 470 |
start_row = max(0, min(start_row, self.num_rows - height_cells))
|
| 471 |
|
| 472 |
-
# Convert back to pixels
|
| 473 |
snapped_x1 = start_col * self.cell_width
|
| 474 |
snapped_y1 = start_row * self.cell_height
|
| 475 |
snapped_x2 = (start_col + width_cells) * self.cell_width
|
|
@@ -497,7 +494,7 @@ class GridLayoutSystem:
|
|
| 497 |
|
| 498 |
|
| 499 |
# ============================================================================
|
| 500 |
-
# OVERLAP DETECTION & RESOLUTION
|
| 501 |
# ============================================================================
|
| 502 |
class OverlapResolver:
|
| 503 |
"""Detects and resolves overlapping elements."""
|
|
@@ -506,7 +503,7 @@ class OverlapResolver:
|
|
| 506 |
self.elements = elements
|
| 507 |
self.img_width = img_width
|
| 508 |
self.img_height = img_height
|
| 509 |
-
self.overlap_threshold = 0.2
|
| 510 |
|
| 511 |
def compute_iou(self, bbox1: List[float], bbox2: List[float]) -> float:
|
| 512 |
"""Compute Intersection over Union between two bounding boxes."""
|
|
@@ -545,7 +542,7 @@ class OverlapResolver:
|
|
| 545 |
return overlap_ratio1, overlap_ratio2
|
| 546 |
|
| 547 |
def resolve_overlaps(self, normalized_elements: List[NormalizedElement]) -> List[NormalizedElement]:
|
| 548 |
-
"""Resolve overlaps by adjusting element positions
|
| 549 |
print("\nπ Checking for overlaps...")
|
| 550 |
|
| 551 |
overlaps = []
|
|
@@ -579,7 +576,6 @@ class OverlapResolver:
|
|
| 579 |
|
| 580 |
print(f"⚠️ Found {len(overlaps)} overlapping element pairs")
|
| 581 |
|
| 582 |
-
# Sort by overlap severity
|
| 583 |
overlaps.sort(key=lambda x: x['overlap'], reverse=True)
|
| 584 |
|
| 585 |
elements_to_remove = set()
|
|
@@ -595,7 +591,6 @@ class OverlapResolver:
|
|
| 595 |
elem2 = overlap_info['elem2']
|
| 596 |
overlap_ratio = overlap_info['overlap']
|
| 597 |
|
| 598 |
-
# Strategy 1: Nearly complete overlap (>70%) - remove lower confidence
|
| 599 |
if overlap_ratio > 0.7:
|
| 600 |
if elem1.original.score < elem2.original.score:
|
| 601 |
elements_to_remove.add(idx1)
|
|
@@ -606,13 +601,11 @@ class OverlapResolver:
|
|
| 606 |
print(f" ποΈ Removing {elem2.original.label} (conf: {elem2.original.score:.2f}) - "
|
| 607 |
f"overlaps {overlap_ratio * 100:.1f}% with {elem1.original.label}")
|
| 608 |
|
| 609 |
-
# Strategy 2: Significant overlap (40-70%) - try to separate
|
| 610 |
elif overlap_ratio > 0.4:
|
| 611 |
self._try_separate_elements(elem1, elem2, overlap_info)
|
| 612 |
print(f" βοΈ Separating {elem1.original.label} and {elem2.original.label} "
|
| 613 |
f"(overlap: {overlap_ratio * 100:.1f}%)")
|
| 614 |
|
| 615 |
-
# Strategy 3: Moderate overlap (20-40%) - shrink slightly
|
| 616 |
else:
|
| 617 |
self._shrink_overlapping_edges(elem1, elem2, overlap_info)
|
| 618 |
print(f" π Shrinking {elem1.original.label} and {elem2.original.label} "
|
|
@@ -629,11 +622,10 @@ class OverlapResolver:
|
|
| 629 |
|
| 630 |
def _try_separate_elements(self, elem1: NormalizedElement, elem2: NormalizedElement,
|
| 631 |
overlap_info: Dict):
|
| 632 |
-
"""Try to separate two significantly overlapping elements
|
| 633 |
bbox1 = elem1.normalized_bbox
|
| 634 |
bbox2 = elem2.normalized_bbox
|
| 635 |
|
| 636 |
-
# Calculate overlap dimensions
|
| 637 |
overlap_x1 = max(bbox1[0], bbox2[0])
|
| 638 |
overlap_y1 = max(bbox1[1], bbox2[1])
|
| 639 |
overlap_x2 = min(bbox1[2], bbox2[2])
|
|
@@ -642,45 +634,35 @@ class OverlapResolver:
|
|
| 642 |
overlap_width = overlap_x2 - overlap_x1
|
| 643 |
overlap_height = overlap_y2 - overlap_y1
|
| 644 |
|
| 645 |
-
# Calculate centers
|
| 646 |
center1_x = (bbox1[0] + bbox1[2]) / 2
|
| 647 |
center1_y = (bbox1[1] + bbox1[3]) / 2
|
| 648 |
center2_x = (bbox2[0] + bbox2[2]) / 2
|
| 649 |
center2_y = (bbox2[1] + bbox2[3]) / 2
|
| 650 |
|
| 651 |
-
# Determine separation direction
|
| 652 |
dx = abs(center2_x - center1_x)
|
| 653 |
dy = abs(center2_y - center1_y)
|
| 654 |
|
| 655 |
-
|
| 656 |
-
min_gap = 3 # pixels
|
| 657 |
|
| 658 |
if dx > dy:
|
| 659 |
-
# Separate horizontally
|
| 660 |
if center1_x < center2_x:
|
| 661 |
-
# elem1 is left of elem2
|
| 662 |
midpoint = (bbox1[2] + bbox2[0]) / 2
|
| 663 |
bbox1[2] = midpoint - min_gap
|
| 664 |
bbox2[0] = midpoint + min_gap
|
| 665 |
else:
|
| 666 |
-
# elem2 is left of elem1
|
| 667 |
midpoint = (bbox2[2] + bbox1[0]) / 2
|
| 668 |
bbox2[2] = midpoint - min_gap
|
| 669 |
bbox1[0] = midpoint + min_gap
|
| 670 |
else:
|
| 671 |
-
# Separate vertically
|
| 672 |
if center1_y < center2_y:
|
| 673 |
-
# elem1 is above elem2
|
| 674 |
midpoint = (bbox1[3] + bbox2[1]) / 2
|
| 675 |
bbox1[3] = midpoint - min_gap
|
| 676 |
bbox2[1] = midpoint + min_gap
|
| 677 |
else:
|
| 678 |
-
# elem2 is above elem1
|
| 679 |
midpoint = (bbox2[3] + bbox1[1]) / 2
|
| 680 |
bbox2[3] = midpoint - min_gap
|
| 681 |
bbox1[1] = midpoint + min_gap
|
| 682 |
|
| 683 |
-
# Ensure boxes remain valid
|
| 684 |
self._ensure_valid_bbox(bbox1)
|
| 685 |
self._ensure_valid_bbox(bbox2)
|
| 686 |
|
|
@@ -690,7 +672,6 @@ class OverlapResolver:
|
|
| 690 |
bbox1 = elem1.normalized_bbox
|
| 691 |
bbox2 = elem2.normalized_bbox
|
| 692 |
|
| 693 |
-
# Calculate overlap region
|
| 694 |
overlap_x1 = max(bbox1[0], bbox2[0])
|
| 695 |
overlap_y1 = max(bbox1[1], bbox2[1])
|
| 696 |
overlap_x2 = min(bbox1[2], bbox2[2])
|
|
@@ -699,11 +680,9 @@ class OverlapResolver:
|
|
| 699 |
overlap_width = overlap_x2 - overlap_x1
|
| 700 |
overlap_height = overlap_y2 - overlap_y1
|
| 701 |
|
| 702 |
-
|
| 703 |
-
gap = 2 # pixels
|
| 704 |
|
| 705 |
if overlap_width > overlap_height:
|
| 706 |
-
# Horizontal overlap is larger
|
| 707 |
shrink = overlap_width / 2 + gap
|
| 708 |
if bbox1[0] < bbox2[0]:
|
| 709 |
bbox1[2] -= shrink
|
|
@@ -712,7 +691,6 @@ class OverlapResolver:
|
|
| 712 |
bbox2[2] -= shrink
|
| 713 |
bbox1[0] += shrink
|
| 714 |
else:
|
| 715 |
-
# Vertical overlap is larger
|
| 716 |
shrink = overlap_height / 2 + gap
|
| 717 |
if bbox1[1] < bbox2[1]:
|
| 718 |
bbox1[3] -= shrink
|
|
@@ -726,9 +704,8 @@ class OverlapResolver:
|
|
| 726 |
|
| 727 |
def _ensure_valid_bbox(self, bbox: List[float]):
|
| 728 |
"""Ensure bounding box has minimum size and is within image bounds."""
|
| 729 |
-
min_size = 8
|
| 730 |
|
| 731 |
-
# Ensure minimum size
|
| 732 |
if bbox[2] - bbox[0] < min_size:
|
| 733 |
center_x = (bbox[0] + bbox[2]) / 2
|
| 734 |
bbox[0] = center_x - min_size / 2
|
|
@@ -739,7 +716,6 @@ class OverlapResolver:
|
|
| 739 |
bbox[1] = center_y - min_size / 2
|
| 740 |
bbox[3] = center_y + min_size / 2
|
| 741 |
|
| 742 |
-
# Clamp to image bounds
|
| 743 |
bbox[0] = max(0, min(bbox[0], self.img_width))
|
| 744 |
bbox[1] = max(0, min(bbox[1], self.img_height))
|
| 745 |
bbox[2] = max(0, min(bbox[2], self.img_width))
|
|
@@ -747,7 +723,7 @@ class OverlapResolver:
|
|
| 747 |
|
| 748 |
|
| 749 |
# ============================================================================
|
| 750 |
-
# MAIN NORMALIZATION ENGINE
|
| 751 |
# ============================================================================
|
| 752 |
class LayoutNormalizer:
|
| 753 |
"""Main engine for normalizing wireframe layout."""
|
|
@@ -764,7 +740,6 @@ class LayoutNormalizer:
|
|
| 764 |
"""Normalize all elements with proper sizing and alignment."""
|
| 765 |
print("\nπ§ Starting layout normalization...")
|
| 766 |
|
| 767 |
-
# Step 1: Detect alignments
|
| 768 |
h_alignments = self.alignment_detector.detect_horizontal_alignments()
|
| 769 |
v_alignments = self.alignment_detector.detect_vertical_alignments()
|
| 770 |
edge_alignments = self.alignment_detector.detect_edge_alignments()
|
|
@@ -772,11 +747,9 @@ class LayoutNormalizer:
|
|
| 772 |
print(f"β Found {len(h_alignments)} horizontal alignment groups")
|
| 773 |
print(f"β Found {len(v_alignments)} vertical alignment groups")
|
| 774 |
|
| 775 |
-
# Step 2: Cluster sizes by type
|
| 776 |
size_clusters = self.size_normalizer.cluster_sizes_by_type()
|
| 777 |
print(f"β Created size clusters for {len(size_clusters)} element types")
|
| 778 |
|
| 779 |
-
# Step 3: Create element-to-cluster mapping
|
| 780 |
element_to_cluster = {}
|
| 781 |
element_to_size_category = {}
|
| 782 |
for label, clusters in size_clusters.items():
|
|
@@ -786,18 +759,14 @@ class LayoutNormalizer:
|
|
| 786 |
element_to_cluster[id(elem)] = cluster
|
| 787 |
element_to_size_category[id(elem)] = category
|
| 788 |
|
| 789 |
-
# Step 4: Normalize each element
|
| 790 |
normalized_elements = []
|
| 791 |
|
| 792 |
for elem in self.elements:
|
| 793 |
-
# Get size cluster
|
| 794 |
cluster = element_to_cluster.get(id(elem), [elem])
|
| 795 |
size_category = element_to_size_category.get(id(elem), f"{elem.label}_default")
|
| 796 |
|
| 797 |
-
# Get normalized size
|
| 798 |
norm_width, norm_height = self.size_normalizer.get_normalized_size(elem, cluster)
|
| 799 |
|
| 800 |
-
# Create normalized bbox (centered on original)
|
| 801 |
center_x, center_y = elem.center_x, elem.center_y
|
| 802 |
norm_bbox = [
|
| 803 |
center_x - norm_width / 2,
|
|
@@ -806,7 +775,6 @@ class LayoutNormalizer:
|
|
| 806 |
center_y + norm_height / 2
|
| 807 |
]
|
| 808 |
|
| 809 |
-
# Snap to grid - preserve original size better
|
| 810 |
snapped_bbox = self.grid.snap_to_grid(norm_bbox, elem.label, preserve_size=True)
|
| 811 |
grid_position = self.grid.get_grid_position(snapped_bbox)
|
| 812 |
|
|
@@ -817,12 +785,10 @@ class LayoutNormalizer:
|
|
| 817 |
size_category=size_category
|
| 818 |
))
|
| 819 |
|
| 820 |
-
# Step 5: Apply alignment corrections
|
| 821 |
normalized_elements = self._apply_alignment_corrections(
|
| 822 |
normalized_elements, h_alignments, v_alignments, edge_alignments
|
| 823 |
)
|
| 824 |
|
| 825 |
-
# Step 6: Resolve overlaps
|
| 826 |
overlap_resolver = OverlapResolver(self.elements, self.img_width, self.img_height)
|
| 827 |
normalized_elements = overlap_resolver.resolve_overlaps(normalized_elements)
|
| 828 |
|
|
@@ -835,32 +801,26 @@ class LayoutNormalizer:
|
|
| 835 |
edge_alignments: Dict) -> List[NormalizedElement]:
|
| 836 |
"""Apply alignment corrections to normalized elements."""
|
| 837 |
|
| 838 |
-
# Create lookup dictionary
|
| 839 |
elem_to_normalized = {id(ne.original): ne for ne in normalized_elements}
|
| 840 |
|
| 841 |
-
# Align horizontally grouped elements
|
| 842 |
for h_group in h_alignments:
|
| 843 |
norm_group = [elem_to_normalized[id(e)] for e in h_group if id(e) in elem_to_normalized]
|
| 844 |
if len(norm_group) > 1:
|
| 845 |
-
# Align to average Y position
|
| 846 |
avg_y = sum((ne.normalized_bbox[1] + ne.normalized_bbox[3]) / 2 for ne in norm_group) / len(norm_group)
|
| 847 |
for ne in norm_group:
|
| 848 |
height = ne.normalized_bbox[3] - ne.normalized_bbox[1]
|
| 849 |
ne.normalized_bbox[1] = avg_y - height / 2
|
| 850 |
ne.normalized_bbox[3] = avg_y + height / 2
|
| 851 |
|
| 852 |
-
# Align vertically grouped elements
|
| 853 |
for v_group in v_alignments:
|
| 854 |
norm_group = [elem_to_normalized[id(e)] for e in v_group if id(e) in elem_to_normalized]
|
| 855 |
if len(norm_group) > 1:
|
| 856 |
-
# Align to average X position
|
| 857 |
avg_x = sum((ne.normalized_bbox[0] + ne.normalized_bbox[2]) / 2 for ne in norm_group) / len(norm_group)
|
| 858 |
for ne in norm_group:
|
| 859 |
width = ne.normalized_bbox[2] - ne.normalized_bbox[0]
|
| 860 |
ne.normalized_bbox[0] = avg_x - width / 2
|
| 861 |
ne.normalized_bbox[2] = avg_x + width / 2
|
| 862 |
|
| 863 |
-
# Align edges
|
| 864 |
for edge_type, groups in edge_alignments.items():
|
| 865 |
for edge_group in groups:
|
| 866 |
norm_group = [elem_to_normalized[id(e)] for e in edge_group if id(e) in elem_to_normalized]
|
|
@@ -894,7 +854,7 @@ class LayoutNormalizer:
|
|
| 894 |
|
| 895 |
|
| 896 |
# ============================================================================
|
| 897 |
-
# VISUALIZATION & EXPORT
|
| 898 |
# ============================================================================
|
| 899 |
def visualize_comparison(pil_img: Image.Image, elements: List[Element],
|
| 900 |
normalized_elements: List[NormalizedElement],
|
|
@@ -903,7 +863,6 @@ def visualize_comparison(pil_img: Image.Image, elements: List[Element],
|
|
| 903 |
|
| 904 |
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 12))
|
| 905 |
|
| 906 |
-
# Original detections
|
| 907 |
ax1.imshow(pil_img)
|
| 908 |
ax1.set_title("Original Predictions", fontsize=16, weight='bold')
|
| 909 |
ax1.axis('off')
|
|
@@ -918,12 +877,10 @@ def visualize_comparison(pil_img: Image.Image, elements: List[Element],
|
|
| 918 |
ax1.text(x1, y1 - 5, elem.label, color='red', fontsize=8,
|
| 919 |
bbox=dict(facecolor='white', alpha=0.7))
|
| 920 |
|
| 921 |
-
# Normalized layout
|
| 922 |
ax2.imshow(pil_img)
|
| 923 |
ax2.set_title("Normalized & Aligned Layout", fontsize=16, weight='bold')
|
| 924 |
ax2.axis('off')
|
| 925 |
|
| 926 |
-
# Draw grid
|
| 927 |
for x in range(grid_system.num_columns + 1):
|
| 928 |
x_pos = x * grid_system.cell_width
|
| 929 |
ax2.axvline(x=x_pos, color='blue', linestyle=':', linewidth=0.5, alpha=0.3)
|
|
@@ -931,7 +888,6 @@ def visualize_comparison(pil_img: Image.Image, elements: List[Element],
|
|
| 931 |
y_pos = y * grid_system.cell_height
|
| 932 |
ax2.axhline(y=y_pos, color='blue', linestyle=':', linewidth=0.5, alpha=0.3)
|
| 933 |
|
| 934 |
-
# Draw normalized elements
|
| 935 |
np.random.seed(42)
|
| 936 |
colors = plt.cm.Set3(np.linspace(0, 1, len(CLASS_NAMES)))
|
| 937 |
color_map = {name: colors[i] for i, name in enumerate(CLASS_NAMES)}
|
|
@@ -940,14 +896,12 @@ def visualize_comparison(pil_img: Image.Image, elements: List[Element],
|
|
| 940 |
x1, y1, x2, y2 = norm_elem.normalized_bbox
|
| 941 |
color = color_map[norm_elem.original.label]
|
| 942 |
|
| 943 |
-
# Normalized box (thick)
|
| 944 |
rect = patches.Rectangle(
|
| 945 |
(x1, y1), x2 - x1, y2 - y1,
|
| 946 |
linewidth=3, edgecolor=color, facecolor='none'
|
| 947 |
)
|
| 948 |
ax2.add_patch(rect)
|
| 949 |
|
| 950 |
-
# Original box (thin, dashed)
|
| 951 |
ox1, oy1, ox2, oy2 = norm_elem.original.bbox
|
| 952 |
orig_rect = patches.Rectangle(
|
| 953 |
(ox1, oy1), ox2 - ox1, oy2 - oy1,
|
|
@@ -956,7 +910,6 @@ def visualize_comparison(pil_img: Image.Image, elements: List[Element],
|
|
| 956 |
)
|
| 957 |
ax2.add_patch(orig_rect)
|
| 958 |
|
| 959 |
-
# Label
|
| 960 |
grid_pos = norm_elem.grid_position
|
| 961 |
label_text = f"{norm_elem.original.label}\n{norm_elem.size_category}\nR{grid_pos['start_row']} C{grid_pos['start_col']}"
|
| 962 |
ax2.text(x1 + 5, y1 + 15, label_text, color='white', fontsize=7,
|
|
@@ -1086,7 +1039,6 @@ def export_to_html(normalized_elements: List[NormalizedElement],
|
|
| 1086 |
text-transform: uppercase;
|
| 1087 |
}}
|
| 1088 |
|
| 1089 |
-
/* Element type specific styles */
|
| 1090 |
.button {{
|
| 1091 |
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 1092 |
color: white;
|
|
@@ -1160,7 +1112,7 @@ def export_to_html(normalized_elements: List[NormalizedElement],
|
|
| 1160 |
</head>
|
| 1161 |
<body>
|
| 1162 |
<div class="info-panel">
|
| 1163 |
-
<h3>
|
| 1164 |
<p><strong>Grid:</strong> {grid_cols} × {grid_rows}</p>
|
| 1165 |
<p><strong>Elements:</strong> {total_elements}</p>
|
| 1166 |
<p><strong>Dimensions:</strong> {img_width}px × {img_height}px</p>
|
|
@@ -1208,7 +1160,7 @@ def export_to_html(normalized_elements: List[NormalizedElement],
|
|
| 1208 |
|
| 1209 |
|
| 1210 |
# ============================================================================
|
| 1211 |
-
# MAIN PIPELINE
|
| 1212 |
# ============================================================================
|
| 1213 |
def process_wireframe(image_path: str,
|
| 1214 |
save_json: bool = True,
|
|
@@ -1230,51 +1182,46 @@ def process_wireframe(image_path: str,
|
|
| 1230 |
print("=== PROCESS_WIREFRAME START ===")
|
| 1231 |
print("Input image path:", image_path)
|
| 1232 |
print("File exists:", os.path.exists(image_path))
|
| 1233 |
-
|
|
|
|
| 1234 |
|
| 1235 |
print("=" * 80)
|
| 1236 |
-
print("π WIREFRAME LAYOUT NORMALIZER")
|
| 1237 |
print("=" * 80)
|
| 1238 |
|
| 1239 |
-
# Step 1: Load model and get predictions
|
| 1240 |
-
global
|
| 1241 |
-
|
| 1242 |
-
|
| 1243 |
-
|
| 1244 |
-
print("
|
| 1245 |
-
print("Attempting to load keras model:", MODEL_PATH)
|
| 1246 |
-
print("Loaded model summary:")
|
| 1247 |
-
model.summary(print_fn=lambda x: print(x))
|
| 1248 |
try:
|
| 1249 |
-
|
| 1250 |
-
|
| 1251 |
-
|
| 1252 |
-
)
|
| 1253 |
-
print("
|
|
|
|
| 1254 |
except Exception as e:
|
| 1255 |
-
print(f"β Error loading model: {e}")
|
| 1256 |
-
|
| 1257 |
-
try:
|
| 1258 |
-
model = tf.keras.models.load_model(MODEL_PATH, compile=False)
|
| 1259 |
-
print("✅ Model loaded successfully (without compilation)!")
|
| 1260 |
-
except Exception as e2:
|
| 1261 |
-
print(f"β Failed to load model: {e2}")
|
| 1262 |
-
return {}
|
| 1263 |
|
| 1264 |
print(f"\nπΈ Processing image: {image_path}")
|
| 1265 |
print("Running detection inference…")
|
| 1266 |
-
|
| 1267 |
-
|
| 1268 |
-
print("
|
| 1269 |
-
|
| 1270 |
-
|
|
|
|
|
|
|
|
|
|
| 1271 |
|
| 1272 |
if not elements:
|
| 1273 |
-
print("β οΈ No
|
| 1274 |
-
print("β Meaning model.predict returned zero raw boxes.")
|
| 1275 |
print("β Check thresholds:")
|
| 1276 |
-
print("CONF_THRESHOLD:
|
| 1277 |
-
print("IOU_THRESHOLD:
|
| 1278 |
return {}
|
| 1279 |
|
| 1280 |
# Step 2: Normalize layout
|
|
@@ -1314,7 +1261,6 @@ def process_wireframe(image_path: str,
|
|
| 1314 |
print("π PROCESSING SUMMARY")
|
| 1315 |
print("=" * 80)
|
| 1316 |
|
| 1317 |
-
# Count by type
|
| 1318 |
type_counts = {}
|
| 1319 |
for elem in elements:
|
| 1320 |
type_counts[elem.label] = type_counts.get(elem.label, 0) + 1
|
|
@@ -1323,18 +1269,16 @@ def process_wireframe(image_path: str,
|
|
| 1323 |
for elem_type, count in sorted(type_counts.items()):
|
| 1324 |
print(f" • {elem_type}: {count}")
|
| 1325 |
|
| 1326 |
-
# Size categories
|
| 1327 |
size_categories = {}
|
| 1328 |
for norm_elem in normalized_elements:
|
| 1329 |
size_categories[norm_elem.size_category] = size_categories.get(norm_elem.size_category, 0) + 1
|
| 1330 |
|
| 1331 |
-
print(f"\n
|
| 1332 |
|
| 1333 |
-
# Alignment info
|
| 1334 |
h_alignments = normalizer.alignment_detector.detect_horizontal_alignments()
|
| 1335 |
v_alignments = normalizer.alignment_detector.detect_vertical_alignments()
|
| 1336 |
|
| 1337 |
-
print(f"\n
|
| 1338 |
print(f" • Horizontal groups: {len(h_alignments)}")
|
| 1339 |
print(f" • Vertical groups: {len(v_alignments)}")
|
| 1340 |
|
|
@@ -1383,7 +1327,6 @@ def batch_process(image_dir: str, pattern: str = "*.png"):
|
|
| 1383 |
'error': str(e)
|
| 1384 |
})
|
| 1385 |
|
| 1386 |
-
# Summary
|
| 1387 |
successful = sum(1 for r in all_results if r['success'])
|
| 1388 |
print(f"\n{'=' * 80}")
|
| 1389 |
print(f"π BATCH PROCESSING COMPLETE")
|
|
|
|
| 1 |
+
import onnxruntime as ort
|
| 2 |
import numpy as np
|
| 3 |
import matplotlib.pyplot as plt
|
| 4 |
import matplotlib.patches as patches
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
# ============================================================================
|
| 14 |
+
# CONFIGURATION - UPDATED FOR ONNX
|
| 15 |
# ============================================================================
|
| 16 |
+
MODEL_PATH = "./wireframe_detection_model_best_700.onnx"  # ONNX model file opened via ort.InferenceSession. NOTE(review): the model file appears to be renamed to "wireframe.onnx" in this commit - confirm this path exists.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
OUTPUT_DIR = "./output/"
|
| 18 |
CLASS_NAMES = ["button", "checkbox", "image", "navbar", "paragraph", "text", "textfield"]
|
| 19 |
|
|
|
|
| 21 |
CONF_THRESHOLD = 0.1
|
| 22 |
IOU_THRESHOLD = 0.1
|
| 23 |
|
| 24 |
+
# Layout Configuration
|
| 25 |
+
GRID_COLUMNS = 24
|
| 26 |
+
ALIGNMENT_THRESHOLD = 10
|
| 27 |
+
SIZE_CLUSTERING_THRESHOLD = 15
|
| 28 |
|
| 29 |
+
# Standard sizes for each element type (relative units)
|
| 30 |
STANDARD_SIZES = {
|
| 31 |
+
'button': {'width': 2, 'height': 1},
|
| 32 |
+
'checkbox': {'width': 1, 'height': 1},
|
| 33 |
+
'textfield': {'width': 5, 'height': 1},
|
| 34 |
+
'text': {'width': 3, 'height': 1},
|
| 35 |
+
'paragraph': {'width': 8, 'height': 2},
|
| 36 |
+
'image': {'width': 4, 'height': 4},
|
| 37 |
+
'navbar': {'width': 24, 'height': 1}
|
| 38 |
}
|
| 39 |
|
| 40 |
+
ort_session = None # Changed from model to ort_session
|
| 41 |
|
| 42 |
|
| 43 |
# ============================================================================
|
| 44 |
+
# UTILITY FUNCTIONS FOR ONNX
|
| 45 |
+
# ============================================================================
|
| 46 |
+
def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` of a scalar or array.

    The computation only ever exponentiates ``-|x|``, so it cannot overflow
    for any floating dtype. The previous ``clip(x, -500, 500)`` guard was only
    sufficient for float64: ``exp(500)`` overflows float32 (the dtype ONNX
    Runtime returns) and emitted a RuntimeWarning for strongly negative logits.

    Args:
        x: Scalar or array-like of logits. Non-floating input is promoted to
            float64; floating input keeps its dtype.

    Returns:
        NumPy scalar for scalar input, otherwise an ndarray of the same shape,
        with values in [0, 1].
    """
    logits = np.asarray(x)
    if not np.issubdtype(logits.dtype, np.floating):
        logits = logits.astype(np.float64)

    # decay is always in (0, 1], so neither branch below can overflow.
    decay = np.exp(-np.abs(logits))
    result = np.where(logits >= 0, 1 / (1 + decay), decay / (1 + decay))

    # Indexing with () turns a 0-d array back into a scalar and is a no-op
    # (view) for real arrays, matching the original scalar-in/scalar-out use.
    return result[()]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def softmax(x, axis=-1):
    """Compute softmax probabilities along ``axis``.

    The per-slice maximum is subtracted before exponentiation so that large
    logits do not overflow; this does not change the result.

    Args:
        x: Array of logits.
        axis: Axis along which the probabilities sum to 1 (default: last).

    Returns:
        Array of the same shape as ``x`` holding the normalized exponentials.
    """
    peak = np.max(x, axis=axis, keepdims=True)
    weights = np.exp(x - peak)
    total = np.sum(weights, axis=axis, keepdims=True)
    return weights / total
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def non_max_suppression_numpy(boxes, scores, iou_threshold=0.5, score_threshold=0.1):
    """
    Pure NumPy implementation of greedy Non-Maximum Suppression.

    Boxes scoring below ``score_threshold`` are discarded first. The remaining
    boxes are visited in descending score order; each visited box is kept and
    every lower-scored box whose IoU with it exceeds ``iou_threshold`` is
    suppressed.

    Args:
        boxes: Array-like of shape (N, 4) with format [x1, y1, x2, y2]
        scores: Array-like of shape (N,) with confidence scores
        iou_threshold: IoU threshold for suppression
        score_threshold: Minimum score threshold

    Returns:
        List of plain ``int`` indices to keep, in descending score order.
        The indices refer to the ORIGINAL ``boxes``/``scores`` arguments, so
        callers can index their own unfiltered lists with them.
    """
    boxes = np.asarray(boxes, dtype=np.float64).reshape(-1, 4)
    scores = np.asarray(scores, dtype=np.float64).reshape(-1)

    if boxes.shape[0] == 0:
        return []

    # Remember which original rows survive the score filter. Previously the
    # returned indices pointed into the filtered arrays, which silently
    # selected the wrong boxes whenever anything was filtered out.
    candidates = np.flatnonzero(scores >= score_threshold)
    if candidates.size == 0:
        return []

    x1, y1, x2, y2 = boxes[candidates].T
    candidate_scores = scores[candidates]
    areas = (x2 - x1) * (y2 - y1)

    # Positions (within the candidate set) sorted by descending score.
    order = candidate_scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        best, rest = order[0], order[1:]
        keep.append(int(candidates[best]))
        if rest.size == 0:
            break

        # Intersection rectangle of the best box with every remaining box.
        inter_w = np.maximum(0.0, np.minimum(x2[best], x2[rest]) - np.maximum(x1[best], x1[rest]))
        inter_h = np.maximum(0.0, np.minimum(y2[best], y2[rest]) - np.maximum(y1[best], y1[rest]))
        intersection = inter_w * inter_h
        union = areas[best] + areas[rest] - intersection

        # Degenerate (zero-area) pairs have no measurable overlap: treat their
        # IoU as 0 instead of dividing 0 by 0 and propagating NaN.
        iou = np.divide(
            intersection,
            union,
            out=np.zeros_like(intersection),
            where=union > 0,
        )

        # Only boxes that do not overlap the kept box too much stay in play.
        order = rest[iou <= iou_threshold]

    return keep
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# ============================================================================
|
| 119 |
+
# DATA STRUCTURES (unchanged)
|
| 120 |
# ============================================================================
|
| 121 |
@dataclass
|
| 122 |
class Element:
|
|
|
|
| 147 |
|
| 148 |
|
| 149 |
# ============================================================================
|
| 150 |
+
# PREDICTION EXTRACTION - MODIFIED FOR ONNX
|
| 151 |
# ============================================================================
|
| 152 |
def get_predictions(image_path: str) -> Tuple[Image.Image, List[Element]]:
|
| 153 |
+
"""Extract predictions from the ONNX model."""
|
| 154 |
+
global ort_session
|
| 155 |
+
if ort_session is None:
|
| 156 |
+
raise ValueError("ONNX model not loaded. Please load the model first.")
|
| 157 |
|
| 158 |
# Load and preprocess image
|
| 159 |
pil_img = Image.open(image_path).convert("RGB")
|
|
|
|
| 163 |
img_array = np.array(resized_img, dtype=np.float32) / 255.0
|
| 164 |
input_tensor = np.expand_dims(img_array, axis=0)
|
| 165 |
|
| 166 |
+
# Get predictions from ONNX model
|
| 167 |
+
input_name = ort_session.get_inputs()[0].name
|
| 168 |
+
output_name = ort_session.get_outputs()[0].name
|
| 169 |
+
pred_grid = ort_session.run([output_name], {input_name: input_tensor})[0][0]
|
| 170 |
+
|
| 171 |
raw_boxes = []
|
| 172 |
S = pred_grid.shape[0]
|
| 173 |
cell_size = 1.0 / S
|
| 174 |
|
| 175 |
for row in range(S):
|
| 176 |
for col in range(S):
|
| 177 |
+
obj_score = float(sigmoid(pred_grid[row, col, 0]))
|
| 178 |
if obj_score < CONF_THRESHOLD:
|
| 179 |
continue
|
| 180 |
|
| 181 |
+
x_offset = float(sigmoid(pred_grid[row, col, 1]))
|
| 182 |
+
y_offset = float(sigmoid(pred_grid[row, col, 2]))
|
| 183 |
+
width = float(sigmoid(pred_grid[row, col, 3]))
|
| 184 |
+
height = float(sigmoid(pred_grid[row, col, 4]))
|
| 185 |
|
| 186 |
class_logits = pred_grid[row, col, 5:]
|
| 187 |
+
class_probs = softmax(class_logits)
|
| 188 |
class_id = int(np.argmax(class_probs))
|
| 189 |
class_conf = float(class_probs[class_id])
|
| 190 |
final_score = obj_score * class_conf
|
|
|
|
| 202 |
if x2 > x1 and y2 > y1:
|
| 203 |
raw_boxes.append((class_id, final_score, x1, y1, x2, y2))
|
| 204 |
|
| 205 |
+
# Apply NMS per class using NumPy implementation
|
| 206 |
elements = []
|
| 207 |
for class_id in range(len(CLASS_NAMES)):
|
| 208 |
class_boxes = [(score, x1, y1, x2, y2) for cid, score, x1, y1, x2, y2 in raw_boxes if cid == class_id]
|
| 209 |
if not class_boxes:
|
| 210 |
continue
|
| 211 |
|
| 212 |
+
scores = np.array([b[0] for b in class_boxes])
|
| 213 |
+
boxes_xyxy = np.array([[b[1], b[2], b[3], b[4]] for b in class_boxes])
|
| 214 |
|
| 215 |
+
selected_indices = non_max_suppression_numpy(
|
| 216 |
boxes=boxes_xyxy,
|
| 217 |
scores=scores,
|
|
|
|
| 218 |
iou_threshold=IOU_THRESHOLD,
|
| 219 |
score_threshold=CONF_THRESHOLD
|
| 220 |
)
|
| 221 |
|
| 222 |
+
for idx in selected_indices:
|
| 223 |
score, x1, y1, x2, y2 = class_boxes[idx]
|
| 224 |
elements.append(Element(
|
| 225 |
label=CLASS_NAMES[class_id],
|
|
|
|
| 231 |
|
| 232 |
|
| 233 |
# ============================================================================
|
| 234 |
+
# ALIGNMENT DETECTION (unchanged)
|
| 235 |
# ============================================================================
|
| 236 |
class AlignmentDetector:
|
| 237 |
"""Detects alignment relationships between elements."""
|
|
|
|
| 343 |
|
| 344 |
|
| 345 |
# ============================================================================
|
| 346 |
+
# SIZE NORMALIZATION (unchanged)
|
| 347 |
# ============================================================================
|
| 348 |
class SizeNormalizer:
|
| 349 |
"""Normalizes element sizes based on type and clustering."""
|
|
|
|
| 394 |
return clusters
|
| 395 |
|
| 396 |
def get_normalized_size(self, element: Element, size_cluster: List[Element]) -> Tuple[float, float]:
|
| 397 |
+
"""Get normalized size for an element based on its cluster."""
|
|
|
|
|
|
|
| 398 |
if len(size_cluster) >= 3:
|
|
|
|
| 399 |
widths = sorted([e.width for e in size_cluster])
|
| 400 |
heights = sorted([e.height for e in size_cluster])
|
| 401 |
median_width = widths[len(widths) // 2]
|
| 402 |
median_height = heights[len(heights) // 2]
|
| 403 |
|
|
|
|
| 404 |
if abs(element.width - median_width) / median_width < 0.3:
|
| 405 |
normalized_width = round(median_width)
|
| 406 |
else:
|
|
|
|
| 411 |
else:
|
| 412 |
normalized_height = round(element.height)
|
| 413 |
else:
|
|
|
|
| 414 |
normalized_width = round(element.width)
|
| 415 |
normalized_height = round(element.height)
|
| 416 |
|
|
|
|
| 418 |
|
| 419 |
|
| 420 |
# ============================================================================
|
| 421 |
+
# GRID-BASED LAYOUT SYSTEM (unchanged)
|
| 422 |
# ============================================================================
|
| 423 |
class GridLayoutSystem:
|
| 424 |
"""Grid-based layout system for precise positioning."""
|
|
|
|
| 437 |
print(f"π Cell size: {self.cell_width:.1f}px Γ {self.cell_height:.1f}px")
|
| 438 |
|
| 439 |
def snap_to_grid(self, bbox: List[float], element_label: str, preserve_size: bool = True) -> List[float]:
|
| 440 |
+
"""Snap bounding box to grid."""
|
| 441 |
x1, y1, x2, y2 = bbox
|
| 442 |
original_width = x2 - x1
|
| 443 |
original_height = y2 - y1
|
| 444 |
|
|
|
|
| 445 |
center_x = (x1 + x2) / 2
|
| 446 |
center_y = (y1 + y2) / 2
|
| 447 |
|
|
|
|
| 448 |
center_col = round(center_x / self.cell_width)
|
| 449 |
center_row = round(center_y / self.cell_height)
|
| 450 |
|
| 451 |
if preserve_size:
|
|
|
|
| 452 |
width_cells = max(1, round(original_width / self.cell_width))
|
| 453 |
height_cells = max(1, round(original_height / self.cell_height))
|
| 454 |
else:
|
|
|
|
| 455 |
standard = STANDARD_SIZES.get(element_label, {'width': 2, 'height': 1})
|
| 456 |
width_cells = max(1, round(original_width / self.cell_width))
|
| 457 |
height_cells = max(1, round(original_height / self.cell_height))
|
| 458 |
|
|
|
|
| 459 |
if abs(width_cells - standard['width']) <= 0.5:
|
| 460 |
width_cells = standard['width']
|
| 461 |
if abs(height_cells - standard['height']) <= 0.5:
|
| 462 |
height_cells = standard['height']
|
| 463 |
|
|
|
|
| 464 |
start_col = center_col - width_cells // 2
|
| 465 |
start_row = center_row - height_cells // 2
|
| 466 |
|
|
|
|
| 467 |
start_col = max(0, min(start_col, self.num_columns - width_cells))
|
| 468 |
start_row = max(0, min(start_row, self.num_rows - height_cells))
|
| 469 |
|
|
|
|
| 470 |
snapped_x1 = start_col * self.cell_width
|
| 471 |
snapped_y1 = start_row * self.cell_height
|
| 472 |
snapped_x2 = (start_col + width_cells) * self.cell_width
|
|
|
|
| 494 |
|
| 495 |
|
| 496 |
# ============================================================================
|
| 497 |
+
# OVERLAP DETECTION & RESOLUTION (unchanged)
|
| 498 |
# ============================================================================
|
| 499 |
class OverlapResolver:
|
| 500 |
"""Detects and resolves overlapping elements."""
|
|
|
|
| 503 |
self.elements = elements
|
| 504 |
self.img_width = img_width
|
| 505 |
self.img_height = img_height
|
| 506 |
+
self.overlap_threshold = 0.2
|
| 507 |
|
| 508 |
def compute_iou(self, bbox1: List[float], bbox2: List[float]) -> float:
|
| 509 |
"""Compute Intersection over Union between two bounding boxes."""
|
|
|
|
| 542 |
return overlap_ratio1, overlap_ratio2
|
| 543 |
|
| 544 |
def resolve_overlaps(self, normalized_elements: List[NormalizedElement]) -> List[NormalizedElement]:
|
| 545 |
+
"""Resolve overlaps by adjusting element positions."""
|
| 546 |
print("\nπ Checking for overlaps...")
|
| 547 |
|
| 548 |
overlaps = []
|
|
|
|
| 576 |
|
| 577 |
print(f"β οΈ Found {len(overlaps)} overlapping element pairs")
|
| 578 |
|
|
|
|
| 579 |
overlaps.sort(key=lambda x: x['overlap'], reverse=True)
|
| 580 |
|
| 581 |
elements_to_remove = set()
|
|
|
|
| 591 |
elem2 = overlap_info['elem2']
|
| 592 |
overlap_ratio = overlap_info['overlap']
|
| 593 |
|
|
|
|
| 594 |
if overlap_ratio > 0.7:
|
| 595 |
if elem1.original.score < elem2.original.score:
|
| 596 |
elements_to_remove.add(idx1)
|
|
|
|
| 601 |
print(f" ποΈ Removing {elem2.original.label} (conf: {elem2.original.score:.2f}) - "
|
| 602 |
f"overlaps {overlap_ratio * 100:.1f}% with {elem1.original.label}")
|
| 603 |
|
|
|
|
| 604 |
elif overlap_ratio > 0.4:
|
| 605 |
self._try_separate_elements(elem1, elem2, overlap_info)
|
| 606 |
print(f" βοΈ Separating {elem1.original.label} and {elem2.original.label} "
|
| 607 |
f"(overlap: {overlap_ratio * 100:.1f}%)")
|
| 608 |
|
|
|
|
| 609 |
else:
|
| 610 |
self._shrink_overlapping_edges(elem1, elem2, overlap_info)
|
| 611 |
print(f" π Shrinking {elem1.original.label} and {elem2.original.label} "
|
|
|
|
| 622 |
|
| 623 |
def _try_separate_elements(self, elem1: NormalizedElement, elem2: NormalizedElement,
|
| 624 |
overlap_info: Dict):
|
| 625 |
+
"""Try to separate two significantly overlapping elements."""
|
| 626 |
bbox1 = elem1.normalized_bbox
|
| 627 |
bbox2 = elem2.normalized_bbox
|
| 628 |
|
|
|
|
| 629 |
overlap_x1 = max(bbox1[0], bbox2[0])
|
| 630 |
overlap_y1 = max(bbox1[1], bbox2[1])
|
| 631 |
overlap_x2 = min(bbox1[2], bbox2[2])
|
|
|
|
| 634 |
overlap_width = overlap_x2 - overlap_x1
|
| 635 |
overlap_height = overlap_y2 - overlap_y1
|
| 636 |
|
|
|
|
| 637 |
center1_x = (bbox1[0] + bbox1[2]) / 2
|
| 638 |
center1_y = (bbox1[1] + bbox1[3]) / 2
|
| 639 |
center2_x = (bbox2[0] + bbox2[2]) / 2
|
| 640 |
center2_y = (bbox2[1] + bbox2[3]) / 2
|
| 641 |
|
|
|
|
| 642 |
dx = abs(center2_x - center1_x)
|
| 643 |
dy = abs(center2_y - center1_y)
|
| 644 |
|
| 645 |
+
min_gap = 3
|
|
|
|
| 646 |
|
| 647 |
if dx > dy:
|
|
|
|
| 648 |
if center1_x < center2_x:
|
|
|
|
| 649 |
midpoint = (bbox1[2] + bbox2[0]) / 2
|
| 650 |
bbox1[2] = midpoint - min_gap
|
| 651 |
bbox2[0] = midpoint + min_gap
|
| 652 |
else:
|
|
|
|
| 653 |
midpoint = (bbox2[2] + bbox1[0]) / 2
|
| 654 |
bbox2[2] = midpoint - min_gap
|
| 655 |
bbox1[0] = midpoint + min_gap
|
| 656 |
else:
|
|
|
|
| 657 |
if center1_y < center2_y:
|
|
|
|
| 658 |
midpoint = (bbox1[3] + bbox2[1]) / 2
|
| 659 |
bbox1[3] = midpoint - min_gap
|
| 660 |
bbox2[1] = midpoint + min_gap
|
| 661 |
else:
|
|
|
|
| 662 |
midpoint = (bbox2[3] + bbox1[1]) / 2
|
| 663 |
bbox2[3] = midpoint - min_gap
|
| 664 |
bbox1[1] = midpoint + min_gap
|
| 665 |
|
|
|
|
| 666 |
self._ensure_valid_bbox(bbox1)
|
| 667 |
self._ensure_valid_bbox(bbox2)
|
| 668 |
|
|
|
|
| 672 |
bbox1 = elem1.normalized_bbox
|
| 673 |
bbox2 = elem2.normalized_bbox
|
| 674 |
|
|
|
|
| 675 |
overlap_x1 = max(bbox1[0], bbox2[0])
|
| 676 |
overlap_y1 = max(bbox1[1], bbox2[1])
|
| 677 |
overlap_x2 = min(bbox1[2], bbox2[2])
|
|
|
|
| 680 |
overlap_width = overlap_x2 - overlap_x1
|
| 681 |
overlap_height = overlap_y2 - overlap_y1
|
| 682 |
|
| 683 |
+
gap = 2
|
|
|
|
| 684 |
|
| 685 |
if overlap_width > overlap_height:
|
|
|
|
| 686 |
shrink = overlap_width / 2 + gap
|
| 687 |
if bbox1[0] < bbox2[0]:
|
| 688 |
bbox1[2] -= shrink
|
|
|
|
| 691 |
bbox2[2] -= shrink
|
| 692 |
bbox1[0] += shrink
|
| 693 |
else:
|
|
|
|
| 694 |
shrink = overlap_height / 2 + gap
|
| 695 |
if bbox1[1] < bbox2[1]:
|
| 696 |
bbox1[3] -= shrink
|
|
|
|
| 704 |
|
| 705 |
def _ensure_valid_bbox(self, bbox: List[float]):
|
| 706 |
"""Ensure bounding box has minimum size and is within image bounds."""
|
| 707 |
+
min_size = 8
|
| 708 |
|
|
|
|
| 709 |
if bbox[2] - bbox[0] < min_size:
|
| 710 |
center_x = (bbox[0] + bbox[2]) / 2
|
| 711 |
bbox[0] = center_x - min_size / 2
|
|
|
|
| 716 |
bbox[1] = center_y - min_size / 2
|
| 717 |
bbox[3] = center_y + min_size / 2
|
| 718 |
|
|
|
|
| 719 |
bbox[0] = max(0, min(bbox[0], self.img_width))
|
| 720 |
bbox[1] = max(0, min(bbox[1], self.img_height))
|
| 721 |
bbox[2] = max(0, min(bbox[2], self.img_width))
|
|
|
|
| 723 |
|
| 724 |
|
| 725 |
# ============================================================================
|
| 726 |
+
# MAIN NORMALIZATION ENGINE (unchanged)
|
| 727 |
# ============================================================================
|
| 728 |
class LayoutNormalizer:
|
| 729 |
"""Main engine for normalizing wireframe layout."""
|
|
|
|
| 740 |
"""Normalize all elements with proper sizing and alignment."""
|
| 741 |
print("\nπ§ Starting layout normalization...")
|
| 742 |
|
|
|
|
| 743 |
h_alignments = self.alignment_detector.detect_horizontal_alignments()
|
| 744 |
v_alignments = self.alignment_detector.detect_vertical_alignments()
|
| 745 |
edge_alignments = self.alignment_detector.detect_edge_alignments()
|
|
|
|
| 747 |
print(f"β Found {len(h_alignments)} horizontal alignment groups")
|
| 748 |
print(f"β Found {len(v_alignments)} vertical alignment groups")
|
| 749 |
|
|
|
|
| 750 |
size_clusters = self.size_normalizer.cluster_sizes_by_type()
|
| 751 |
print(f"β Created size clusters for {len(size_clusters)} element types")
|
| 752 |
|
|
|
|
| 753 |
element_to_cluster = {}
|
| 754 |
element_to_size_category = {}
|
| 755 |
for label, clusters in size_clusters.items():
|
|
|
|
| 759 |
element_to_cluster[id(elem)] = cluster
|
| 760 |
element_to_size_category[id(elem)] = category
|
| 761 |
|
|
|
|
| 762 |
normalized_elements = []
|
| 763 |
|
| 764 |
for elem in self.elements:
|
|
|
|
| 765 |
cluster = element_to_cluster.get(id(elem), [elem])
|
| 766 |
size_category = element_to_size_category.get(id(elem), f"{elem.label}_default")
|
| 767 |
|
|
|
|
| 768 |
norm_width, norm_height = self.size_normalizer.get_normalized_size(elem, cluster)
|
| 769 |
|
|
|
|
| 770 |
center_x, center_y = elem.center_x, elem.center_y
|
| 771 |
norm_bbox = [
|
| 772 |
center_x - norm_width / 2,
|
|
|
|
| 775 |
center_y + norm_height / 2
|
| 776 |
]
|
| 777 |
|
|
|
|
| 778 |
snapped_bbox = self.grid.snap_to_grid(norm_bbox, elem.label, preserve_size=True)
|
| 779 |
grid_position = self.grid.get_grid_position(snapped_bbox)
|
| 780 |
|
|
|
|
| 785 |
size_category=size_category
|
| 786 |
))
|
| 787 |
|
|
|
|
| 788 |
normalized_elements = self._apply_alignment_corrections(
|
| 789 |
normalized_elements, h_alignments, v_alignments, edge_alignments
|
| 790 |
)
|
| 791 |
|
|
|
|
| 792 |
overlap_resolver = OverlapResolver(self.elements, self.img_width, self.img_height)
|
| 793 |
normalized_elements = overlap_resolver.resolve_overlaps(normalized_elements)
|
| 794 |
|
|
|
|
| 801 |
edge_alignments: Dict) -> List[NormalizedElement]:
|
| 802 |
"""Apply alignment corrections to normalized elements."""
|
| 803 |
|
|
|
|
| 804 |
elem_to_normalized = {id(ne.original): ne for ne in normalized_elements}
|
| 805 |
|
|
|
|
| 806 |
for h_group in h_alignments:
|
| 807 |
norm_group = [elem_to_normalized[id(e)] for e in h_group if id(e) in elem_to_normalized]
|
| 808 |
if len(norm_group) > 1:
|
|
|
|
| 809 |
avg_y = sum((ne.normalized_bbox[1] + ne.normalized_bbox[3]) / 2 for ne in norm_group) / len(norm_group)
|
| 810 |
for ne in norm_group:
|
| 811 |
height = ne.normalized_bbox[3] - ne.normalized_bbox[1]
|
| 812 |
ne.normalized_bbox[1] = avg_y - height / 2
|
| 813 |
ne.normalized_bbox[3] = avg_y + height / 2
|
| 814 |
|
|
|
|
| 815 |
for v_group in v_alignments:
|
| 816 |
norm_group = [elem_to_normalized[id(e)] for e in v_group if id(e) in elem_to_normalized]
|
| 817 |
if len(norm_group) > 1:
|
|
|
|
| 818 |
avg_x = sum((ne.normalized_bbox[0] + ne.normalized_bbox[2]) / 2 for ne in norm_group) / len(norm_group)
|
| 819 |
for ne in norm_group:
|
| 820 |
width = ne.normalized_bbox[2] - ne.normalized_bbox[0]
|
| 821 |
ne.normalized_bbox[0] = avg_x - width / 2
|
| 822 |
ne.normalized_bbox[2] = avg_x + width / 2
|
| 823 |
|
|
|
|
| 824 |
for edge_type, groups in edge_alignments.items():
|
| 825 |
for edge_group in groups:
|
| 826 |
norm_group = [elem_to_normalized[id(e)] for e in edge_group if id(e) in elem_to_normalized]
|
|
|
|
| 854 |
|
| 855 |
|
| 856 |
# ============================================================================
|
| 857 |
+
# VISUALIZATION & EXPORT (unchanged)
|
| 858 |
# ============================================================================
|
| 859 |
def visualize_comparison(pil_img: Image.Image, elements: List[Element],
|
| 860 |
normalized_elements: List[NormalizedElement],
|
|
|
|
| 863 |
|
| 864 |
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 12))
|
| 865 |
|
|
|
|
| 866 |
ax1.imshow(pil_img)
|
| 867 |
ax1.set_title("Original Predictions", fontsize=16, weight='bold')
|
| 868 |
ax1.axis('off')
|
|
|
|
| 877 |
ax1.text(x1, y1 - 5, elem.label, color='red', fontsize=8,
|
| 878 |
bbox=dict(facecolor='white', alpha=0.7))
|
| 879 |
|
|
|
|
| 880 |
ax2.imshow(pil_img)
|
| 881 |
ax2.set_title("Normalized & Aligned Layout", fontsize=16, weight='bold')
|
| 882 |
ax2.axis('off')
|
| 883 |
|
|
|
|
| 884 |
for x in range(grid_system.num_columns + 1):
|
| 885 |
x_pos = x * grid_system.cell_width
|
| 886 |
ax2.axvline(x=x_pos, color='blue', linestyle=':', linewidth=0.5, alpha=0.3)
|
|
|
|
| 888 |
y_pos = y * grid_system.cell_height
|
| 889 |
ax2.axhline(y=y_pos, color='blue', linestyle=':', linewidth=0.5, alpha=0.3)
|
| 890 |
|
|
|
|
| 891 |
np.random.seed(42)
|
| 892 |
colors = plt.cm.Set3(np.linspace(0, 1, len(CLASS_NAMES)))
|
| 893 |
color_map = {name: colors[i] for i, name in enumerate(CLASS_NAMES)}
|
|
|
|
| 896 |
x1, y1, x2, y2 = norm_elem.normalized_bbox
|
| 897 |
color = color_map[norm_elem.original.label]
|
| 898 |
|
|
|
|
| 899 |
rect = patches.Rectangle(
|
| 900 |
(x1, y1), x2 - x1, y2 - y1,
|
| 901 |
linewidth=3, edgecolor=color, facecolor='none'
|
| 902 |
)
|
| 903 |
ax2.add_patch(rect)
|
| 904 |
|
|
|
|
| 905 |
ox1, oy1, ox2, oy2 = norm_elem.original.bbox
|
| 906 |
orig_rect = patches.Rectangle(
|
| 907 |
(ox1, oy1), ox2 - ox1, oy2 - oy1,
|
|
|
|
| 910 |
)
|
| 911 |
ax2.add_patch(orig_rect)
|
| 912 |
|
|
|
|
| 913 |
grid_pos = norm_elem.grid_position
|
| 914 |
label_text = f"{norm_elem.original.label}\n{norm_elem.size_category}\nR{grid_pos['start_row']} C{grid_pos['start_col']}"
|
| 915 |
ax2.text(x1 + 5, y1 + 15, label_text, color='white', fontsize=7,
|
|
|
|
| 1039 |
text-transform: uppercase;
|
| 1040 |
}}
|
| 1041 |
|
|
|
|
| 1042 |
.button {{
|
| 1043 |
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 1044 |
color: white;
|
|
|
|
| 1112 |
</head>
|
| 1113 |
<body>
|
| 1114 |
<div class="info-panel">
|
| 1115 |
+
<h3>π Layout Info</h3>
|
| 1116 |
<p><strong>Grid:</strong> {grid_cols} Γ {grid_rows}</p>
|
| 1117 |
<p><strong>Elements:</strong> {total_elements}</p>
|
| 1118 |
<p><strong>Dimensions:</strong> {img_width}px Γ {img_height}px</p>
|
|
|
|
| 1160 |
|
| 1161 |
|
| 1162 |
# ============================================================================
|
| 1163 |
+
# MAIN PIPELINE - MODIFIED FOR ONNX
|
| 1164 |
# ============================================================================
|
| 1165 |
def process_wireframe(image_path: str,
|
| 1166 |
save_json: bool = True,
|
|
|
|
| 1182 |
print("=== PROCESS_WIREFRAME START ===")
|
| 1183 |
print("Input image path:", image_path)
|
| 1184 |
print("File exists:", os.path.exists(image_path))
|
| 1185 |
+
if os.path.exists(image_path):
|
| 1186 |
+
print("File size:", os.path.getsize(image_path))
|
| 1187 |
|
| 1188 |
print("=" * 80)
|
| 1189 |
+
print("π WIREFRAME LAYOUT NORMALIZER (ONNX)")
|
| 1190 |
print("=" * 80)
|
| 1191 |
|
| 1192 |
+
# Step 1: Load ONNX model and get predictions
|
| 1193 |
+
global ort_session
|
| 1194 |
+
if ort_session is None:
|
| 1195 |
+
print("\nπ¦ Loading ONNX model...")
|
| 1196 |
+
print("Model path:", MODEL_PATH)
|
| 1197 |
+
print("Model path exists?", os.path.exists(MODEL_PATH))
|
|
|
|
|
|
|
|
|
|
| 1198 |
try:
|
| 1199 |
+
ort_session = ort.InferenceSession(MODEL_PATH)
|
| 1200 |
+
print("β
ONNX model loaded successfully!")
|
| 1201 |
+
print(f"Input name: {ort_session.get_inputs()[0].name}")
|
| 1202 |
+
print(f"Input shape: {ort_session.get_inputs()[0].shape}")
|
| 1203 |
+
print(f"Output name: {ort_session.get_outputs()[0].name}")
|
| 1204 |
+
print(f"Output shape: {ort_session.get_outputs()[0].shape}")
|
| 1205 |
except Exception as e:
|
| 1206 |
+
print(f"β Error loading ONNX model: {e}")
|
| 1207 |
+
return {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1208 |
|
| 1209 |
print(f"\nπΈ Processing image: {image_path}")
|
| 1210 |
print("Running detection inferenceβ¦")
|
| 1211 |
+
try:
|
| 1212 |
+
pil_img, elements = get_predictions(image_path)
|
| 1213 |
+
print(f"β
Detected {len(elements)} elements")
|
| 1214 |
+
for elem in elements:
|
| 1215 |
+
print(f" - {elem.label} (conf: {elem.score:.3f}) at {elem.bbox}")
|
| 1216 |
+
except Exception as e:
|
| 1217 |
+
print(f"β Error during prediction: {e}")
|
| 1218 |
+
return {}
|
| 1219 |
|
| 1220 |
if not elements:
|
| 1221 |
+
print("β οΈ No elements detected.")
|
|
|
|
| 1222 |
print("β Check thresholds:")
|
| 1223 |
+
print(f" CONF_THRESHOLD: {CONF_THRESHOLD}")
|
| 1224 |
+
print(f" IOU_THRESHOLD: {IOU_THRESHOLD}")
|
| 1225 |
return {}
|
| 1226 |
|
| 1227 |
# Step 2: Normalize layout
|
|
|
|
| 1261 |
print("π PROCESSING SUMMARY")
|
| 1262 |
print("=" * 80)
|
| 1263 |
|
|
|
|
| 1264 |
type_counts = {}
|
| 1265 |
for elem in elements:
|
| 1266 |
type_counts[elem.label] = type_counts.get(elem.label, 0) + 1
|
|
|
|
| 1269 |
for elem_type, count in sorted(type_counts.items()):
|
| 1270 |
print(f" β’ {elem_type}: {count}")
|
| 1271 |
|
|
|
|
| 1272 |
size_categories = {}
|
| 1273 |
for norm_elem in normalized_elements:
|
| 1274 |
size_categories[norm_elem.size_category] = size_categories.get(norm_elem.size_category, 0) + 1
|
| 1275 |
|
| 1276 |
+
print(f"\nπ Size Categories: {len(size_categories)}")
|
| 1277 |
|
|
|
|
| 1278 |
h_alignments = normalizer.alignment_detector.detect_horizontal_alignments()
|
| 1279 |
v_alignments = normalizer.alignment_detector.detect_vertical_alignments()
|
| 1280 |
|
| 1281 |
+
print(f"\nπ Alignment:")
|
| 1282 |
print(f" β’ Horizontal groups: {len(h_alignments)}")
|
| 1283 |
print(f" β’ Vertical groups: {len(v_alignments)}")
|
| 1284 |
|
|
|
|
| 1327 |
'error': str(e)
|
| 1328 |
})
|
| 1329 |
|
|
|
|
| 1330 |
successful = sum(1 for r in all_results if r['success'])
|
| 1331 |
print(f"\n{'=' * 80}")
|
| 1332 |
print(f"π BATCH PROCESSING COMPLETE")
|
wireframe_detection_model_best_700.keras β wireframe.onnx
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2c47e1f0f63b4a29dd146331c582860e5981ea0546119b79511a167e856a6277
|
| 3 |
+
size 17701338
|