ChristianQ commited on
Commit
e9e7a8d
·
1 Parent(s): 9c67a05

Updated detection module script and requirements.txt to use ONNX models instead of Keras

Browse files
requirements.txt CHANGED
@@ -1,9 +1,9 @@
1
- tensorflow-cpu==2.15.0
2
- gradio
3
- fastapi
4
- uvicorn
5
  numpy
6
  pillow
7
- matplotlib
8
  opencv-python-headless
9
- python-multipart
 
 
 
 
 
1
+ onnxruntime==1.19.0
 
 
 
2
  numpy
3
  pillow
 
4
  opencv-python-headless
5
+ fastapi
6
+ uvicorn
7
+ gradio
8
+ python-multipart
9
+ matplotlib
visualization.py CHANGED
@@ -1,4 +1,4 @@
1
- import tensorflow as tf
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
  import matplotlib.patches as patches
@@ -11,76 +11,9 @@ from typing import List, Tuple, Dict, Optional
11
 
12
 
13
  # ============================================================================
14
- # CUSTOM LOSS CLASS (Required for model loading)
15
  # ============================================================================
16
- @tf.keras.utils.register_keras_serializable()
17
- class LossCalculation(tf.keras.losses.Loss):
18
- """Custom loss function for wireframe detection."""
19
-
20
- def __init__(self, num_classes=7, lambda_coord=5.0, lambda_noobj=0.5,
21
- name='loss_calculation', reduction='sum_over_batch_size', **kwargs):
22
- super().__init__(name=name, reduction=reduction)
23
- self.num_classes = num_classes
24
- self.lambda_coord = lambda_coord
25
- self.lambda_noobj = lambda_noobj
26
-
27
- def call(self, y_true, y_pred):
28
- obj_true = y_true[..., 0]
29
- box_true = y_true[..., 1:5]
30
- cls_true = y_true[..., 5:]
31
-
32
- obj_pred_logits = y_pred[..., 0]
33
- box_pred = y_pred[..., 1:5]
34
- cls_pred_logits = y_pred[..., 5:]
35
-
36
- obj_mask = tf.cast(obj_true > 0.5, tf.float32)
37
- noobj_mask = 1.0 - obj_mask
38
- num_pos = tf.maximum(tf.reduce_sum(obj_mask), 1.0)
39
-
40
- obj_loss_pos = obj_mask * tf.nn.sigmoid_cross_entropy_with_logits(
41
- labels=obj_true, logits=obj_pred_logits)
42
- obj_loss_neg = noobj_mask * tf.nn.sigmoid_cross_entropy_with_logits(
43
- labels=obj_true, logits=obj_pred_logits)
44
- obj_loss = (tf.reduce_sum(obj_loss_pos) + self.lambda_noobj * tf.reduce_sum(obj_loss_neg)) / tf.cast(
45
- tf.size(obj_true), tf.float32)
46
-
47
- xy_pred = tf.nn.sigmoid(box_pred[..., 0:2])
48
- wh_pred = tf.nn.sigmoid(box_pred[..., 2:4])
49
- xy_true = box_true[..., 0:2]
50
- wh_true = box_true[..., 2:4]
51
-
52
- xy_loss = tf.reduce_sum(obj_mask[..., tf.newaxis] * self._smooth_l1_loss(xy_true - xy_pred)) / num_pos
53
- wh_loss = tf.reduce_sum(obj_mask[..., tf.newaxis] * self._smooth_l1_loss(wh_true - wh_pred)) / num_pos
54
- box_loss = self.lambda_coord * (xy_loss + wh_loss)
55
-
56
- cls_loss = tf.reduce_sum(obj_mask * tf.nn.softmax_cross_entropy_with_logits(
57
- labels=cls_true, logits=cls_pred_logits)) / num_pos
58
-
59
- total_loss = obj_loss + box_loss + cls_loss
60
- return tf.clip_by_value(total_loss, 0.0, 100.0)
61
-
62
- def _smooth_l1_loss(self, x, beta=1.0):
63
- abs_x = tf.abs(x)
64
- return tf.where(abs_x < beta, 0.5 * x * x / beta, abs_x - 0.5 * beta)
65
-
66
- def get_config(self):
67
- config = super().get_config()
68
- config.update({
69
- 'num_classes': self.num_classes,
70
- 'lambda_coord': self.lambda_coord,
71
- 'lambda_noobj': self.lambda_noobj,
72
- })
73
- return config
74
-
75
- @classmethod
76
- def from_config(cls, config):
77
- return cls(**config)
78
-
79
-
80
- # ============================================================================
81
- # CONFIGURATION - UPDATED FOR BETTER PRECISION
82
- # ============================================================================
83
- MODEL_PATH = "./wireframe_detection_model_best_700.keras"
84
  OUTPUT_DIR = "./output/"
85
  CLASS_NAMES = ["button", "checkbox", "image", "navbar", "paragraph", "text", "textfield"]
86
 
@@ -88,27 +21,102 @@ IMG_SIZE = 416
88
  CONF_THRESHOLD = 0.1
89
  IOU_THRESHOLD = 0.1
90
 
91
- # Layout Configuration - INCREASED GRID DENSITY
92
- GRID_COLUMNS = 24 # Doubled from 12 for finer precision
93
- ALIGNMENT_THRESHOLD = 10 # Reduced from 15 for tighter alignment
94
- SIZE_CLUSTERING_THRESHOLD = 15 # Reduced from 20 for better size grouping
95
 
96
- # Standard sizes for each element type (relative units) - UPDATED FOR SMALLER BUTTONS/CHECKBOXES
97
  STANDARD_SIZES = {
98
- 'button': {'width': 2, 'height': 1}, # Smaller button (was 2x1, now in finer grid)
99
- 'checkbox': {'width': 1, 'height': 1}, # Keep small checkbox
100
- 'textfield': {'width': 5, 'height': 1}, # Adjusted for new grid
101
- 'text': {'width': 3, 'height': 1}, # Adjusted
102
- 'paragraph': {'width': 8, 'height': 2}, # Adjusted
103
- 'image': {'width': 4, 'height': 4}, # Adjusted
104
- 'navbar': {'width': 24, 'height': 1} # Full width in new grid
105
  }
106
 
107
- model = None
108
 
109
 
110
  # ============================================================================
111
- # DATA STRUCTURES
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  # ============================================================================
113
  @dataclass
114
  class Element:
@@ -139,13 +147,13 @@ class NormalizedElement:
139
 
140
 
141
  # ============================================================================
142
- # PREDICTION EXTRACTION
143
  # ============================================================================
144
  def get_predictions(image_path: str) -> Tuple[Image.Image, List[Element]]:
145
- """Extract predictions from the model."""
146
- global model
147
- if model is None:
148
- raise ValueError("Model not loaded. Please load the model first.")
149
 
150
  # Load and preprocess image
151
  pil_img = Image.open(image_path).convert("RGB")
@@ -155,25 +163,28 @@ def get_predictions(image_path: str) -> Tuple[Image.Image, List[Element]]:
155
  img_array = np.array(resized_img, dtype=np.float32) / 255.0
156
  input_tensor = np.expand_dims(img_array, axis=0)
157
 
158
- # Get predictions
159
- pred_grid = model.predict(input_tensor, verbose=0)[0]
 
 
 
160
  raw_boxes = []
161
  S = pred_grid.shape[0]
162
  cell_size = 1.0 / S
163
 
164
  for row in range(S):
165
  for col in range(S):
166
- obj_score = float(tf.nn.sigmoid(pred_grid[row, col, 0]))
167
  if obj_score < CONF_THRESHOLD:
168
  continue
169
 
170
- x_offset = float(tf.nn.sigmoid(pred_grid[row, col, 1]))
171
- y_offset = float(tf.nn.sigmoid(pred_grid[row, col, 2]))
172
- width = float(tf.nn.sigmoid(pred_grid[row, col, 3]))
173
- height = float(tf.nn.sigmoid(pred_grid[row, col, 4]))
174
 
175
  class_logits = pred_grid[row, col, 5:]
176
- class_probs = tf.nn.softmax(class_logits).numpy()
177
  class_id = int(np.argmax(class_probs))
178
  class_conf = float(class_probs[class_id])
179
  final_score = obj_score * class_conf
@@ -191,25 +202,24 @@ def get_predictions(image_path: str) -> Tuple[Image.Image, List[Element]]:
191
  if x2 > x1 and y2 > y1:
192
  raw_boxes.append((class_id, final_score, x1, y1, x2, y2))
193
 
194
- # Apply NMS per class
195
  elements = []
196
  for class_id in range(len(CLASS_NAMES)):
197
  class_boxes = [(score, x1, y1, x2, y2) for cid, score, x1, y1, x2, y2 in raw_boxes if cid == class_id]
198
  if not class_boxes:
199
  continue
200
 
201
- scores = [b[0] for b in class_boxes]
202
- boxes_xyxy = [[b[1], b[2], b[3], b[4]] for b in class_boxes]
203
 
204
- selected_indices = tf.image.non_max_suppression(
205
  boxes=boxes_xyxy,
206
  scores=scores,
207
- max_output_size=50,
208
  iou_threshold=IOU_THRESHOLD,
209
  score_threshold=CONF_THRESHOLD
210
  )
211
 
212
- for idx in selected_indices.numpy():
213
  score, x1, y1, x2, y2 = class_boxes[idx]
214
  elements.append(Element(
215
  label=CLASS_NAMES[class_id],
@@ -221,7 +231,7 @@ def get_predictions(image_path: str) -> Tuple[Image.Image, List[Element]]:
221
 
222
 
223
  # ============================================================================
224
- # ALIGNMENT DETECTION
225
  # ============================================================================
226
  class AlignmentDetector:
227
  """Detects alignment relationships between elements."""
@@ -333,7 +343,7 @@ class AlignmentDetector:
333
 
334
 
335
  # ============================================================================
336
- # SIZE NORMALIZATION - UPDATED TO RESPECT ACTUAL SIZES MORE
337
  # ============================================================================
338
  class SizeNormalizer:
339
  """Normalizes element sizes based on type and clustering."""
@@ -384,17 +394,13 @@ class SizeNormalizer:
384
  return clusters
385
 
386
  def get_normalized_size(self, element: Element, size_cluster: List[Element]) -> Tuple[float, float]:
387
- """Get normalized size for an element based on its cluster - PRESERVES ACTUAL SIZE BETTER."""
388
- # Use the actual detected size instead of aggressive averaging
389
- # Only normalize if there's a significant cluster
390
  if len(size_cluster) >= 3:
391
- # Use median instead of mean to avoid outliers
392
  widths = sorted([e.width for e in size_cluster])
393
  heights = sorted([e.height for e in size_cluster])
394
  median_width = widths[len(widths) // 2]
395
  median_height = heights[len(heights) // 2]
396
 
397
- # Only normalize if element is within 30% of median
398
  if abs(element.width - median_width) / median_width < 0.3:
399
  normalized_width = round(median_width)
400
  else:
@@ -405,7 +411,6 @@ class SizeNormalizer:
405
  else:
406
  normalized_height = round(element.height)
407
  else:
408
- # Small cluster - keep original size
409
  normalized_width = round(element.width)
410
  normalized_height = round(element.height)
411
 
@@ -413,7 +418,7 @@ class SizeNormalizer:
413
 
414
 
415
  # ============================================================================
416
- # GRID-BASED LAYOUT SYSTEM - UPDATED FOR FINER PRECISION
417
  # ============================================================================
418
  class GridLayoutSystem:
419
  """Grid-based layout system for precise positioning."""
@@ -432,44 +437,36 @@ class GridLayoutSystem:
432
  print(f"πŸ“ Cell size: {self.cell_width:.1f}px Γ— {self.cell_height:.1f}px")
433
 
434
  def snap_to_grid(self, bbox: List[float], element_label: str, preserve_size: bool = True) -> List[float]:
435
- """Snap bounding box to grid - UPDATED TO PRESERVE ORIGINAL SIZE BETTER."""
436
  x1, y1, x2, y2 = bbox
437
  original_width = x2 - x1
438
  original_height = y2 - y1
439
 
440
- # Calculate center
441
  center_x = (x1 + x2) / 2
442
  center_y = (y1 + y2) / 2
443
 
444
- # Find nearest grid cell for center
445
  center_col = round(center_x / self.cell_width)
446
  center_row = round(center_y / self.cell_height)
447
 
448
  if preserve_size:
449
- # Calculate span based on actual size (don't force to standard)
450
  width_cells = max(1, round(original_width / self.cell_width))
451
  height_cells = max(1, round(original_height / self.cell_height))
452
  else:
453
- # Use standard size
454
  standard = STANDARD_SIZES.get(element_label, {'width': 2, 'height': 1})
455
  width_cells = max(1, round(original_width / self.cell_width))
456
  height_cells = max(1, round(original_height / self.cell_height))
457
 
458
- # Only adjust to standard if very close
459
  if abs(width_cells - standard['width']) <= 0.5:
460
  width_cells = standard['width']
461
  if abs(height_cells - standard['height']) <= 0.5:
462
  height_cells = standard['height']
463
 
464
- # Calculate start position (center the element)
465
  start_col = center_col - width_cells // 2
466
  start_row = center_row - height_cells // 2
467
 
468
- # Clamp to grid bounds
469
  start_col = max(0, min(start_col, self.num_columns - width_cells))
470
  start_row = max(0, min(start_row, self.num_rows - height_cells))
471
 
472
- # Convert back to pixels
473
  snapped_x1 = start_col * self.cell_width
474
  snapped_y1 = start_row * self.cell_height
475
  snapped_x2 = (start_col + width_cells) * self.cell_width
@@ -497,7 +494,7 @@ class GridLayoutSystem:
497
 
498
 
499
  # ============================================================================
500
- # OVERLAP DETECTION & RESOLUTION - UPDATED WITH BETTER STRATEGIES
501
  # ============================================================================
502
  class OverlapResolver:
503
  """Detects and resolves overlapping elements."""
@@ -506,7 +503,7 @@ class OverlapResolver:
506
  self.elements = elements
507
  self.img_width = img_width
508
  self.img_height = img_height
509
- self.overlap_threshold = 0.2 # Reduced from 0.3 - be more aggressive
510
 
511
  def compute_iou(self, bbox1: List[float], bbox2: List[float]) -> float:
512
  """Compute Intersection over Union between two bounding boxes."""
@@ -545,7 +542,7 @@ class OverlapResolver:
545
  return overlap_ratio1, overlap_ratio2
546
 
547
  def resolve_overlaps(self, normalized_elements: List[NormalizedElement]) -> List[NormalizedElement]:
548
- """Resolve overlaps by adjusting element positions - IMPROVED ALGORITHM."""
549
  print("\nπŸ” Checking for overlaps...")
550
 
551
  overlaps = []
@@ -579,7 +576,6 @@ class OverlapResolver:
579
 
580
  print(f"⚠️ Found {len(overlaps)} overlapping element pairs")
581
 
582
- # Sort by overlap severity
583
  overlaps.sort(key=lambda x: x['overlap'], reverse=True)
584
 
585
  elements_to_remove = set()
@@ -595,7 +591,6 @@ class OverlapResolver:
595
  elem2 = overlap_info['elem2']
596
  overlap_ratio = overlap_info['overlap']
597
 
598
- # Strategy 1: Nearly complete overlap (>70%) - remove lower confidence
599
  if overlap_ratio > 0.7:
600
  if elem1.original.score < elem2.original.score:
601
  elements_to_remove.add(idx1)
@@ -606,13 +601,11 @@ class OverlapResolver:
606
  print(f" πŸ—‘οΈ Removing {elem2.original.label} (conf: {elem2.original.score:.2f}) - "
607
  f"overlaps {overlap_ratio * 100:.1f}% with {elem1.original.label}")
608
 
609
- # Strategy 2: Significant overlap (40-70%) - try to separate
610
  elif overlap_ratio > 0.4:
611
  self._try_separate_elements(elem1, elem2, overlap_info)
612
  print(f" ↔️ Separating {elem1.original.label} and {elem2.original.label} "
613
  f"(overlap: {overlap_ratio * 100:.1f}%)")
614
 
615
- # Strategy 3: Moderate overlap (20-40%) - shrink slightly
616
  else:
617
  self._shrink_overlapping_edges(elem1, elem2, overlap_info)
618
  print(f" πŸ“ Shrinking {elem1.original.label} and {elem2.original.label} "
@@ -629,11 +622,10 @@ class OverlapResolver:
629
 
630
  def _try_separate_elements(self, elem1: NormalizedElement, elem2: NormalizedElement,
631
  overlap_info: Dict):
632
- """Try to separate two significantly overlapping elements - IMPROVED."""
633
  bbox1 = elem1.normalized_bbox
634
  bbox2 = elem2.normalized_bbox
635
 
636
- # Calculate overlap dimensions
637
  overlap_x1 = max(bbox1[0], bbox2[0])
638
  overlap_y1 = max(bbox1[1], bbox2[1])
639
  overlap_x2 = min(bbox1[2], bbox2[2])
@@ -642,45 +634,35 @@ class OverlapResolver:
642
  overlap_width = overlap_x2 - overlap_x1
643
  overlap_height = overlap_y2 - overlap_y1
644
 
645
- # Calculate centers
646
  center1_x = (bbox1[0] + bbox1[2]) / 2
647
  center1_y = (bbox1[1] + bbox1[3]) / 2
648
  center2_x = (bbox2[0] + bbox2[2]) / 2
649
  center2_y = (bbox2[1] + bbox2[3]) / 2
650
 
651
- # Determine separation direction
652
  dx = abs(center2_x - center1_x)
653
  dy = abs(center2_y - center1_y)
654
 
655
- # Add minimum gap
656
- min_gap = 3 # pixels
657
 
658
  if dx > dy:
659
- # Separate horizontally
660
  if center1_x < center2_x:
661
- # elem1 is left of elem2
662
  midpoint = (bbox1[2] + bbox2[0]) / 2
663
  bbox1[2] = midpoint - min_gap
664
  bbox2[0] = midpoint + min_gap
665
  else:
666
- # elem2 is left of elem1
667
  midpoint = (bbox2[2] + bbox1[0]) / 2
668
  bbox2[2] = midpoint - min_gap
669
  bbox1[0] = midpoint + min_gap
670
  else:
671
- # Separate vertically
672
  if center1_y < center2_y:
673
- # elem1 is above elem2
674
  midpoint = (bbox1[3] + bbox2[1]) / 2
675
  bbox1[3] = midpoint - min_gap
676
  bbox2[1] = midpoint + min_gap
677
  else:
678
- # elem2 is above elem1
679
  midpoint = (bbox2[3] + bbox1[1]) / 2
680
  bbox2[3] = midpoint - min_gap
681
  bbox1[1] = midpoint + min_gap
682
 
683
- # Ensure boxes remain valid
684
  self._ensure_valid_bbox(bbox1)
685
  self._ensure_valid_bbox(bbox2)
686
 
@@ -690,7 +672,6 @@ class OverlapResolver:
690
  bbox1 = elem1.normalized_bbox
691
  bbox2 = elem2.normalized_bbox
692
 
693
- # Calculate overlap region
694
  overlap_x1 = max(bbox1[0], bbox2[0])
695
  overlap_y1 = max(bbox1[1], bbox2[1])
696
  overlap_x2 = min(bbox1[2], bbox2[2])
@@ -699,11 +680,9 @@ class OverlapResolver:
699
  overlap_width = overlap_x2 - overlap_x1
700
  overlap_height = overlap_y2 - overlap_y1
701
 
702
- # Shrink by 50% of overlap plus small gap
703
- gap = 2 # pixels
704
 
705
  if overlap_width > overlap_height:
706
- # Horizontal overlap is larger
707
  shrink = overlap_width / 2 + gap
708
  if bbox1[0] < bbox2[0]:
709
  bbox1[2] -= shrink
@@ -712,7 +691,6 @@ class OverlapResolver:
712
  bbox2[2] -= shrink
713
  bbox1[0] += shrink
714
  else:
715
- # Vertical overlap is larger
716
  shrink = overlap_height / 2 + gap
717
  if bbox1[1] < bbox2[1]:
718
  bbox1[3] -= shrink
@@ -726,9 +704,8 @@ class OverlapResolver:
726
 
727
  def _ensure_valid_bbox(self, bbox: List[float]):
728
  """Ensure bounding box has minimum size and is within image bounds."""
729
- min_size = 8 # Reduced minimum size
730
 
731
- # Ensure minimum size
732
  if bbox[2] - bbox[0] < min_size:
733
  center_x = (bbox[0] + bbox[2]) / 2
734
  bbox[0] = center_x - min_size / 2
@@ -739,7 +716,6 @@ class OverlapResolver:
739
  bbox[1] = center_y - min_size / 2
740
  bbox[3] = center_y + min_size / 2
741
 
742
- # Clamp to image bounds
743
  bbox[0] = max(0, min(bbox[0], self.img_width))
744
  bbox[1] = max(0, min(bbox[1], self.img_height))
745
  bbox[2] = max(0, min(bbox[2], self.img_width))
@@ -747,7 +723,7 @@ class OverlapResolver:
747
 
748
 
749
  # ============================================================================
750
- # MAIN NORMALIZATION ENGINE
751
  # ============================================================================
752
  class LayoutNormalizer:
753
  """Main engine for normalizing wireframe layout."""
@@ -764,7 +740,6 @@ class LayoutNormalizer:
764
  """Normalize all elements with proper sizing and alignment."""
765
  print("\nπŸ”§ Starting layout normalization...")
766
 
767
- # Step 1: Detect alignments
768
  h_alignments = self.alignment_detector.detect_horizontal_alignments()
769
  v_alignments = self.alignment_detector.detect_vertical_alignments()
770
  edge_alignments = self.alignment_detector.detect_edge_alignments()
@@ -772,11 +747,9 @@ class LayoutNormalizer:
772
  print(f"βœ“ Found {len(h_alignments)} horizontal alignment groups")
773
  print(f"βœ“ Found {len(v_alignments)} vertical alignment groups")
774
 
775
- # Step 2: Cluster sizes by type
776
  size_clusters = self.size_normalizer.cluster_sizes_by_type()
777
  print(f"βœ“ Created size clusters for {len(size_clusters)} element types")
778
 
779
- # Step 3: Create element-to-cluster mapping
780
  element_to_cluster = {}
781
  element_to_size_category = {}
782
  for label, clusters in size_clusters.items():
@@ -786,18 +759,14 @@ class LayoutNormalizer:
786
  element_to_cluster[id(elem)] = cluster
787
  element_to_size_category[id(elem)] = category
788
 
789
- # Step 4: Normalize each element
790
  normalized_elements = []
791
 
792
  for elem in self.elements:
793
- # Get size cluster
794
  cluster = element_to_cluster.get(id(elem), [elem])
795
  size_category = element_to_size_category.get(id(elem), f"{elem.label}_default")
796
 
797
- # Get normalized size
798
  norm_width, norm_height = self.size_normalizer.get_normalized_size(elem, cluster)
799
 
800
- # Create normalized bbox (centered on original)
801
  center_x, center_y = elem.center_x, elem.center_y
802
  norm_bbox = [
803
  center_x - norm_width / 2,
@@ -806,7 +775,6 @@ class LayoutNormalizer:
806
  center_y + norm_height / 2
807
  ]
808
 
809
- # Snap to grid - preserve original size better
810
  snapped_bbox = self.grid.snap_to_grid(norm_bbox, elem.label, preserve_size=True)
811
  grid_position = self.grid.get_grid_position(snapped_bbox)
812
 
@@ -817,12 +785,10 @@ class LayoutNormalizer:
817
  size_category=size_category
818
  ))
819
 
820
- # Step 5: Apply alignment corrections
821
  normalized_elements = self._apply_alignment_corrections(
822
  normalized_elements, h_alignments, v_alignments, edge_alignments
823
  )
824
 
825
- # Step 6: Resolve overlaps
826
  overlap_resolver = OverlapResolver(self.elements, self.img_width, self.img_height)
827
  normalized_elements = overlap_resolver.resolve_overlaps(normalized_elements)
828
 
@@ -835,32 +801,26 @@ class LayoutNormalizer:
835
  edge_alignments: Dict) -> List[NormalizedElement]:
836
  """Apply alignment corrections to normalized elements."""
837
 
838
- # Create lookup dictionary
839
  elem_to_normalized = {id(ne.original): ne for ne in normalized_elements}
840
 
841
- # Align horizontally grouped elements
842
  for h_group in h_alignments:
843
  norm_group = [elem_to_normalized[id(e)] for e in h_group if id(e) in elem_to_normalized]
844
  if len(norm_group) > 1:
845
- # Align to average Y position
846
  avg_y = sum((ne.normalized_bbox[1] + ne.normalized_bbox[3]) / 2 for ne in norm_group) / len(norm_group)
847
  for ne in norm_group:
848
  height = ne.normalized_bbox[3] - ne.normalized_bbox[1]
849
  ne.normalized_bbox[1] = avg_y - height / 2
850
  ne.normalized_bbox[3] = avg_y + height / 2
851
 
852
- # Align vertically grouped elements
853
  for v_group in v_alignments:
854
  norm_group = [elem_to_normalized[id(e)] for e in v_group if id(e) in elem_to_normalized]
855
  if len(norm_group) > 1:
856
- # Align to average X position
857
  avg_x = sum((ne.normalized_bbox[0] + ne.normalized_bbox[2]) / 2 for ne in norm_group) / len(norm_group)
858
  for ne in norm_group:
859
  width = ne.normalized_bbox[2] - ne.normalized_bbox[0]
860
  ne.normalized_bbox[0] = avg_x - width / 2
861
  ne.normalized_bbox[2] = avg_x + width / 2
862
 
863
- # Align edges
864
  for edge_type, groups in edge_alignments.items():
865
  for edge_group in groups:
866
  norm_group = [elem_to_normalized[id(e)] for e in edge_group if id(e) in elem_to_normalized]
@@ -894,7 +854,7 @@ class LayoutNormalizer:
894
 
895
 
896
  # ============================================================================
897
- # VISUALIZATION & EXPORT
898
  # ============================================================================
899
  def visualize_comparison(pil_img: Image.Image, elements: List[Element],
900
  normalized_elements: List[NormalizedElement],
@@ -903,7 +863,6 @@ def visualize_comparison(pil_img: Image.Image, elements: List[Element],
903
 
904
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 12))
905
 
906
- # Original detections
907
  ax1.imshow(pil_img)
908
  ax1.set_title("Original Predictions", fontsize=16, weight='bold')
909
  ax1.axis('off')
@@ -918,12 +877,10 @@ def visualize_comparison(pil_img: Image.Image, elements: List[Element],
918
  ax1.text(x1, y1 - 5, elem.label, color='red', fontsize=8,
919
  bbox=dict(facecolor='white', alpha=0.7))
920
 
921
- # Normalized layout
922
  ax2.imshow(pil_img)
923
  ax2.set_title("Normalized & Aligned Layout", fontsize=16, weight='bold')
924
  ax2.axis('off')
925
 
926
- # Draw grid
927
  for x in range(grid_system.num_columns + 1):
928
  x_pos = x * grid_system.cell_width
929
  ax2.axvline(x=x_pos, color='blue', linestyle=':', linewidth=0.5, alpha=0.3)
@@ -931,7 +888,6 @@ def visualize_comparison(pil_img: Image.Image, elements: List[Element],
931
  y_pos = y * grid_system.cell_height
932
  ax2.axhline(y=y_pos, color='blue', linestyle=':', linewidth=0.5, alpha=0.3)
933
 
934
- # Draw normalized elements
935
  np.random.seed(42)
936
  colors = plt.cm.Set3(np.linspace(0, 1, len(CLASS_NAMES)))
937
  color_map = {name: colors[i] for i, name in enumerate(CLASS_NAMES)}
@@ -940,14 +896,12 @@ def visualize_comparison(pil_img: Image.Image, elements: List[Element],
940
  x1, y1, x2, y2 = norm_elem.normalized_bbox
941
  color = color_map[norm_elem.original.label]
942
 
943
- # Normalized box (thick)
944
  rect = patches.Rectangle(
945
  (x1, y1), x2 - x1, y2 - y1,
946
  linewidth=3, edgecolor=color, facecolor='none'
947
  )
948
  ax2.add_patch(rect)
949
 
950
- # Original box (thin, dashed)
951
  ox1, oy1, ox2, oy2 = norm_elem.original.bbox
952
  orig_rect = patches.Rectangle(
953
  (ox1, oy1), ox2 - ox1, oy2 - oy1,
@@ -956,7 +910,6 @@ def visualize_comparison(pil_img: Image.Image, elements: List[Element],
956
  )
957
  ax2.add_patch(orig_rect)
958
 
959
- # Label
960
  grid_pos = norm_elem.grid_position
961
  label_text = f"{norm_elem.original.label}\n{norm_elem.size_category}\nR{grid_pos['start_row']} C{grid_pos['start_col']}"
962
  ax2.text(x1 + 5, y1 + 15, label_text, color='white', fontsize=7,
@@ -1086,7 +1039,6 @@ def export_to_html(normalized_elements: List[NormalizedElement],
1086
  text-transform: uppercase;
1087
  }}
1088
 
1089
- /* Element type specific styles */
1090
  .button {{
1091
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
1092
  color: white;
@@ -1160,7 +1112,7 @@ def export_to_html(normalized_elements: List[NormalizedElement],
1160
  </head>
1161
  <body>
1162
  <div class="info-panel">
1163
- <h3>πŸ“ Layout Info</h3>
1164
  <p><strong>Grid:</strong> {grid_cols} Γ— {grid_rows}</p>
1165
  <p><strong>Elements:</strong> {total_elements}</p>
1166
  <p><strong>Dimensions:</strong> {img_width}px Γ— {img_height}px</p>
@@ -1208,7 +1160,7 @@ def export_to_html(normalized_elements: List[NormalizedElement],
1208
 
1209
 
1210
  # ============================================================================
1211
- # MAIN PIPELINE
1212
  # ============================================================================
1213
  def process_wireframe(image_path: str,
1214
  save_json: bool = True,
@@ -1230,51 +1182,46 @@ def process_wireframe(image_path: str,
1230
  print("=== PROCESS_WIREFRAME START ===")
1231
  print("Input image path:", image_path)
1232
  print("File exists:", os.path.exists(image_path))
1233
- print("File size:", os.path.getsize(image_path))
 
1234
 
1235
  print("=" * 80)
1236
- print("πŸš€ WIREFRAME LAYOUT NORMALIZER")
1237
  print("=" * 80)
1238
 
1239
- # Step 1: Load model and get predictions
1240
- global model
1241
- print("Model object is None?", model is None)
1242
- print("Model path exists?", os.path.exists(MODEL_PATH))
1243
- if model is None:
1244
- print("\nπŸ“¦ Loading model...")
1245
- print("Attempting to load keras model:", MODEL_PATH)
1246
- print("Loaded model summary:")
1247
- model.summary(print_fn=lambda x: print(x))
1248
  try:
1249
- model = tf.keras.models.load_model(
1250
- MODEL_PATH,
1251
- custom_objects={'LossCalculation': LossCalculation}
1252
- )
1253
- print("βœ… Model loaded successfully!")
 
1254
  except Exception as e:
1255
- print(f"❌ Error loading model: {e}")
1256
- print("\nTrying alternative loading method...")
1257
- try:
1258
- model = tf.keras.models.load_model(MODEL_PATH, compile=False)
1259
- print("βœ… Model loaded successfully (without compilation)!")
1260
- except Exception as e2:
1261
- print(f"❌ Failed to load model: {e2}")
1262
- return {}
1263
 
1264
  print(f"\nπŸ“Έ Processing image: {image_path}")
1265
  print("Running detection inference…")
1266
- print("Elements detected:", len(elements))
1267
- for elem in elements:
1268
- print(" -", elem.label, elem.score, elem.bbox)
1269
- pil_img, elements = get_predictions(image_path)
1270
- print(f"βœ… Detected {len(elements)} elements")
 
 
 
1271
 
1272
  if not elements:
1273
- print("⚠️ No detection output returned.")
1274
- print("β†’ Meaning model.predict returned zero raw boxes.")
1275
  print("β†’ Check thresholds:")
1276
- print("CONF_THRESHOLD:", CONF_THRESHOLD)
1277
- print("IOU_THRESHOLD:", IOU_THRESHOLD)
1278
  return {}
1279
 
1280
  # Step 2: Normalize layout
@@ -1314,7 +1261,6 @@ def process_wireframe(image_path: str,
1314
  print("πŸ“Š PROCESSING SUMMARY")
1315
  print("=" * 80)
1316
 
1317
- # Count by type
1318
  type_counts = {}
1319
  for elem in elements:
1320
  type_counts[elem.label] = type_counts.get(elem.label, 0) + 1
@@ -1323,18 +1269,16 @@ def process_wireframe(image_path: str,
1323
  for elem_type, count in sorted(type_counts.items()):
1324
  print(f" β€’ {elem_type}: {count}")
1325
 
1326
- # Size categories
1327
  size_categories = {}
1328
  for norm_elem in normalized_elements:
1329
  size_categories[norm_elem.size_category] = size_categories.get(norm_elem.size_category, 0) + 1
1330
 
1331
- print(f"\nπŸ“ Size Categories: {len(size_categories)}")
1332
 
1333
- # Alignment info
1334
  h_alignments = normalizer.alignment_detector.detect_horizontal_alignments()
1335
  v_alignments = normalizer.alignment_detector.detect_vertical_alignments()
1336
 
1337
- print(f"\nπŸ“ Alignment:")
1338
  print(f" β€’ Horizontal groups: {len(h_alignments)}")
1339
  print(f" β€’ Vertical groups: {len(v_alignments)}")
1340
 
@@ -1383,7 +1327,6 @@ def batch_process(image_dir: str, pattern: str = "*.png"):
1383
  'error': str(e)
1384
  })
1385
 
1386
- # Summary
1387
  successful = sum(1 for r in all_results if r['success'])
1388
  print(f"\n{'=' * 80}")
1389
  print(f"πŸ“Š BATCH PROCESSING COMPLETE")
 
1
+ import onnxruntime as ort
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
  import matplotlib.patches as patches
 
11
 
12
 
13
  # ============================================================================
14
+ # CONFIGURATION - UPDATED FOR ONNX
15
  # ============================================================================
16
+ MODEL_PATH = "./wireframe_detection_model_best_700.onnx" # Changed to .onnx
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  OUTPUT_DIR = "./output/"
18
  CLASS_NAMES = ["button", "checkbox", "image", "navbar", "paragraph", "text", "textfield"]
19
 
 
21
  CONF_THRESHOLD = 0.1
22
  IOU_THRESHOLD = 0.1
23
 
24
+ # Layout Configuration
25
+ GRID_COLUMNS = 24
26
+ ALIGNMENT_THRESHOLD = 10
27
+ SIZE_CLUSTERING_THRESHOLD = 15
28
 
29
+ # Standard sizes for each element type (relative units)
30
  STANDARD_SIZES = {
31
+ 'button': {'width': 2, 'height': 1},
32
+ 'checkbox': {'width': 1, 'height': 1},
33
+ 'textfield': {'width': 5, 'height': 1},
34
+ 'text': {'width': 3, 'height': 1},
35
+ 'paragraph': {'width': 8, 'height': 2},
36
+ 'image': {'width': 4, 'height': 4},
37
+ 'navbar': {'width': 24, 'height': 1}
38
  }
39
 
40
+ ort_session = None # Changed from model to ort_session
41
 
42
 
43
  # ============================================================================
44
+ # UTILITY FUNCTIONS FOR ONNX
45
+ # ============================================================================
46
+ def sigmoid(x):
47
+ """Sigmoid activation function."""
48
+ return 1 / (1 + np.exp(-np.clip(x, -500, 500)))
49
+
50
+
51
+ def softmax(x, axis=-1):
52
+ """Softmax activation function."""
53
+ exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
54
+ return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
55
+
56
+
57
+ def non_max_suppression_numpy(boxes, scores, iou_threshold=0.5, score_threshold=0.1):
58
+ """
59
+ Pure NumPy implementation of Non-Maximum Suppression.
60
+
61
+ Args:
62
+ boxes: Array of shape (N, 4) with format [x1, y1, x2, y2]
63
+ scores: Array of shape (N,) with confidence scores
64
+ iou_threshold: IoU threshold for suppression
65
+ score_threshold: Minimum score threshold
66
+
67
+ Returns:
68
+ List of indices to keep
69
+ """
70
+ if len(boxes) == 0:
71
+ return []
72
+
73
+ # Filter by score threshold
74
+ keep_mask = scores >= score_threshold
75
+ boxes = boxes[keep_mask]
76
+ scores = scores[keep_mask]
77
+
78
+ if len(boxes) == 0:
79
+ return []
80
+
81
+ # Get coordinates
82
+ x1 = boxes[:, 0]
83
+ y1 = boxes[:, 1]
84
+ x2 = boxes[:, 2]
85
+ y2 = boxes[:, 3]
86
+
87
+ # Calculate areas
88
+ areas = (x2 - x1) * (y2 - y1)
89
+
90
+ # Sort by scores
91
+ order = scores.argsort()[::-1]
92
+
93
+ keep = []
94
+ while order.size > 0:
95
+ # Pick the box with highest score
96
+ i = order[0]
97
+ keep.append(i)
98
+
99
+ # Calculate IoU with remaining boxes
100
+ xx1 = np.maximum(x1[i], x1[order[1:]])
101
+ yy1 = np.maximum(y1[i], y1[order[1:]])
102
+ xx2 = np.minimum(x2[i], x2[order[1:]])
103
+ yy2 = np.minimum(y2[i], y2[order[1:]])
104
+
105
+ w = np.maximum(0.0, xx2 - xx1)
106
+ h = np.maximum(0.0, yy2 - yy1)
107
+
108
+ intersection = w * h
109
+ iou = intersection / (areas[i] + areas[order[1:]] - intersection)
110
+
111
+ # Keep boxes with IoU less than threshold
112
+ inds = np.where(iou <= iou_threshold)[0]
113
+ order = order[inds + 1]
114
+
115
+ return keep
116
+
117
+
118
+ # ============================================================================
119
+ # DATA STRUCTURES (unchanged)
120
  # ============================================================================
121
  @dataclass
122
  class Element:
 
147
 
148
 
149
  # ============================================================================
150
+ # PREDICTION EXTRACTION - MODIFIED FOR ONNX
151
  # ============================================================================
152
  def get_predictions(image_path: str) -> Tuple[Image.Image, List[Element]]:
153
+ """Extract predictions from the ONNX model."""
154
+ global ort_session
155
+ if ort_session is None:
156
+ raise ValueError("ONNX model not loaded. Please load the model first.")
157
 
158
  # Load and preprocess image
159
  pil_img = Image.open(image_path).convert("RGB")
 
163
  img_array = np.array(resized_img, dtype=np.float32) / 255.0
164
  input_tensor = np.expand_dims(img_array, axis=0)
165
 
166
+ # Get predictions from ONNX model
167
+ input_name = ort_session.get_inputs()[0].name
168
+ output_name = ort_session.get_outputs()[0].name
169
+ pred_grid = ort_session.run([output_name], {input_name: input_tensor})[0][0]
170
+
171
  raw_boxes = []
172
  S = pred_grid.shape[0]
173
  cell_size = 1.0 / S
174
 
175
  for row in range(S):
176
  for col in range(S):
177
+ obj_score = float(sigmoid(pred_grid[row, col, 0]))
178
  if obj_score < CONF_THRESHOLD:
179
  continue
180
 
181
+ x_offset = float(sigmoid(pred_grid[row, col, 1]))
182
+ y_offset = float(sigmoid(pred_grid[row, col, 2]))
183
+ width = float(sigmoid(pred_grid[row, col, 3]))
184
+ height = float(sigmoid(pred_grid[row, col, 4]))
185
 
186
  class_logits = pred_grid[row, col, 5:]
187
+ class_probs = softmax(class_logits)
188
  class_id = int(np.argmax(class_probs))
189
  class_conf = float(class_probs[class_id])
190
  final_score = obj_score * class_conf
 
202
  if x2 > x1 and y2 > y1:
203
  raw_boxes.append((class_id, final_score, x1, y1, x2, y2))
204
 
205
+ # Apply NMS per class using NumPy implementation
206
  elements = []
207
  for class_id in range(len(CLASS_NAMES)):
208
  class_boxes = [(score, x1, y1, x2, y2) for cid, score, x1, y1, x2, y2 in raw_boxes if cid == class_id]
209
  if not class_boxes:
210
  continue
211
 
212
+ scores = np.array([b[0] for b in class_boxes])
213
+ boxes_xyxy = np.array([[b[1], b[2], b[3], b[4]] for b in class_boxes])
214
 
215
+ selected_indices = non_max_suppression_numpy(
216
  boxes=boxes_xyxy,
217
  scores=scores,
 
218
  iou_threshold=IOU_THRESHOLD,
219
  score_threshold=CONF_THRESHOLD
220
  )
221
 
222
+ for idx in selected_indices:
223
  score, x1, y1, x2, y2 = class_boxes[idx]
224
  elements.append(Element(
225
  label=CLASS_NAMES[class_id],
 
231
 
232
 
233
  # ============================================================================
234
+ # ALIGNMENT DETECTION (unchanged)
235
  # ============================================================================
236
  class AlignmentDetector:
237
  """Detects alignment relationships between elements."""
 
343
 
344
 
345
  # ============================================================================
346
+ # SIZE NORMALIZATION (unchanged)
347
  # ============================================================================
348
  class SizeNormalizer:
349
  """Normalizes element sizes based on type and clustering."""
 
394
  return clusters
395
 
396
  def get_normalized_size(self, element: Element, size_cluster: List[Element]) -> Tuple[float, float]:
397
+ """Get normalized size for an element based on its cluster."""
 
 
398
  if len(size_cluster) >= 3:
 
399
  widths = sorted([e.width for e in size_cluster])
400
  heights = sorted([e.height for e in size_cluster])
401
  median_width = widths[len(widths) // 2]
402
  median_height = heights[len(heights) // 2]
403
 
 
404
  if abs(element.width - median_width) / median_width < 0.3:
405
  normalized_width = round(median_width)
406
  else:
 
411
  else:
412
  normalized_height = round(element.height)
413
  else:
 
414
  normalized_width = round(element.width)
415
  normalized_height = round(element.height)
416
 
 
418
 
419
 
420
  # ============================================================================
421
+ # GRID-BASED LAYOUT SYSTEM (unchanged)
422
  # ============================================================================
423
  class GridLayoutSystem:
424
  """Grid-based layout system for precise positioning."""
 
437
  print(f"πŸ“ Cell size: {self.cell_width:.1f}px Γ— {self.cell_height:.1f}px")
438
 
439
  def snap_to_grid(self, bbox: List[float], element_label: str, preserve_size: bool = True) -> List[float]:
440
+ """Snap bounding box to grid."""
441
  x1, y1, x2, y2 = bbox
442
  original_width = x2 - x1
443
  original_height = y2 - y1
444
 
 
445
  center_x = (x1 + x2) / 2
446
  center_y = (y1 + y2) / 2
447
 
 
448
  center_col = round(center_x / self.cell_width)
449
  center_row = round(center_y / self.cell_height)
450
 
451
  if preserve_size:
 
452
  width_cells = max(1, round(original_width / self.cell_width))
453
  height_cells = max(1, round(original_height / self.cell_height))
454
  else:
 
455
  standard = STANDARD_SIZES.get(element_label, {'width': 2, 'height': 1})
456
  width_cells = max(1, round(original_width / self.cell_width))
457
  height_cells = max(1, round(original_height / self.cell_height))
458
 
 
459
  if abs(width_cells - standard['width']) <= 0.5:
460
  width_cells = standard['width']
461
  if abs(height_cells - standard['height']) <= 0.5:
462
  height_cells = standard['height']
463
 
 
464
  start_col = center_col - width_cells // 2
465
  start_row = center_row - height_cells // 2
466
 
 
467
  start_col = max(0, min(start_col, self.num_columns - width_cells))
468
  start_row = max(0, min(start_row, self.num_rows - height_cells))
469
 
 
470
  snapped_x1 = start_col * self.cell_width
471
  snapped_y1 = start_row * self.cell_height
472
  snapped_x2 = (start_col + width_cells) * self.cell_width
 
494
 
495
 
496
  # ============================================================================
497
+ # OVERLAP DETECTION & RESOLUTION (unchanged)
498
  # ============================================================================
499
  class OverlapResolver:
500
  """Detects and resolves overlapping elements."""
 
503
  self.elements = elements
504
  self.img_width = img_width
505
  self.img_height = img_height
506
+ self.overlap_threshold = 0.2
507
 
508
  def compute_iou(self, bbox1: List[float], bbox2: List[float]) -> float:
509
  """Compute Intersection over Union between two bounding boxes."""
 
542
  return overlap_ratio1, overlap_ratio2
543
 
544
  def resolve_overlaps(self, normalized_elements: List[NormalizedElement]) -> List[NormalizedElement]:
545
+ """Resolve overlaps by adjusting element positions."""
546
  print("\nπŸ” Checking for overlaps...")
547
 
548
  overlaps = []
 
576
 
577
  print(f"⚠️ Found {len(overlaps)} overlapping element pairs")
578
 
 
579
  overlaps.sort(key=lambda x: x['overlap'], reverse=True)
580
 
581
  elements_to_remove = set()
 
591
  elem2 = overlap_info['elem2']
592
  overlap_ratio = overlap_info['overlap']
593
 
 
594
  if overlap_ratio > 0.7:
595
  if elem1.original.score < elem2.original.score:
596
  elements_to_remove.add(idx1)
 
601
  print(f" πŸ—‘οΈ Removing {elem2.original.label} (conf: {elem2.original.score:.2f}) - "
602
  f"overlaps {overlap_ratio * 100:.1f}% with {elem1.original.label}")
603
 
 
604
  elif overlap_ratio > 0.4:
605
  self._try_separate_elements(elem1, elem2, overlap_info)
606
  print(f" ↔️ Separating {elem1.original.label} and {elem2.original.label} "
607
  f"(overlap: {overlap_ratio * 100:.1f}%)")
608
 
 
609
  else:
610
  self._shrink_overlapping_edges(elem1, elem2, overlap_info)
611
  print(f" πŸ“ Shrinking {elem1.original.label} and {elem2.original.label} "
 
622
 
623
  def _try_separate_elements(self, elem1: NormalizedElement, elem2: NormalizedElement,
624
  overlap_info: Dict):
625
+ """Try to separate two significantly overlapping elements."""
626
  bbox1 = elem1.normalized_bbox
627
  bbox2 = elem2.normalized_bbox
628
 
 
629
  overlap_x1 = max(bbox1[0], bbox2[0])
630
  overlap_y1 = max(bbox1[1], bbox2[1])
631
  overlap_x2 = min(bbox1[2], bbox2[2])
 
634
  overlap_width = overlap_x2 - overlap_x1
635
  overlap_height = overlap_y2 - overlap_y1
636
 
 
637
  center1_x = (bbox1[0] + bbox1[2]) / 2
638
  center1_y = (bbox1[1] + bbox1[3]) / 2
639
  center2_x = (bbox2[0] + bbox2[2]) / 2
640
  center2_y = (bbox2[1] + bbox2[3]) / 2
641
 
 
642
  dx = abs(center2_x - center1_x)
643
  dy = abs(center2_y - center1_y)
644
 
645
+ min_gap = 3
 
646
 
647
  if dx > dy:
 
648
  if center1_x < center2_x:
 
649
  midpoint = (bbox1[2] + bbox2[0]) / 2
650
  bbox1[2] = midpoint - min_gap
651
  bbox2[0] = midpoint + min_gap
652
  else:
 
653
  midpoint = (bbox2[2] + bbox1[0]) / 2
654
  bbox2[2] = midpoint - min_gap
655
  bbox1[0] = midpoint + min_gap
656
  else:
 
657
  if center1_y < center2_y:
 
658
  midpoint = (bbox1[3] + bbox2[1]) / 2
659
  bbox1[3] = midpoint - min_gap
660
  bbox2[1] = midpoint + min_gap
661
  else:
 
662
  midpoint = (bbox2[3] + bbox1[1]) / 2
663
  bbox2[3] = midpoint - min_gap
664
  bbox1[1] = midpoint + min_gap
665
 
 
666
  self._ensure_valid_bbox(bbox1)
667
  self._ensure_valid_bbox(bbox2)
668
 
 
672
  bbox1 = elem1.normalized_bbox
673
  bbox2 = elem2.normalized_bbox
674
 
 
675
  overlap_x1 = max(bbox1[0], bbox2[0])
676
  overlap_y1 = max(bbox1[1], bbox2[1])
677
  overlap_x2 = min(bbox1[2], bbox2[2])
 
680
  overlap_width = overlap_x2 - overlap_x1
681
  overlap_height = overlap_y2 - overlap_y1
682
 
683
+ gap = 2
 
684
 
685
  if overlap_width > overlap_height:
 
686
  shrink = overlap_width / 2 + gap
687
  if bbox1[0] < bbox2[0]:
688
  bbox1[2] -= shrink
 
691
  bbox2[2] -= shrink
692
  bbox1[0] += shrink
693
  else:
 
694
  shrink = overlap_height / 2 + gap
695
  if bbox1[1] < bbox2[1]:
696
  bbox1[3] -= shrink
 
704
 
705
  def _ensure_valid_bbox(self, bbox: List[float]):
706
  """Ensure bounding box has minimum size and is within image bounds."""
707
+ min_size = 8
708
 
 
709
  if bbox[2] - bbox[0] < min_size:
710
  center_x = (bbox[0] + bbox[2]) / 2
711
  bbox[0] = center_x - min_size / 2
 
716
  bbox[1] = center_y - min_size / 2
717
  bbox[3] = center_y + min_size / 2
718
 
 
719
  bbox[0] = max(0, min(bbox[0], self.img_width))
720
  bbox[1] = max(0, min(bbox[1], self.img_height))
721
  bbox[2] = max(0, min(bbox[2], self.img_width))
 
723
 
724
 
725
  # ============================================================================
726
+ # MAIN NORMALIZATION ENGINE (unchanged)
727
  # ============================================================================
728
  class LayoutNormalizer:
729
  """Main engine for normalizing wireframe layout."""
 
740
  """Normalize all elements with proper sizing and alignment."""
741
  print("\nπŸ”§ Starting layout normalization...")
742
 
 
743
  h_alignments = self.alignment_detector.detect_horizontal_alignments()
744
  v_alignments = self.alignment_detector.detect_vertical_alignments()
745
  edge_alignments = self.alignment_detector.detect_edge_alignments()
 
747
  print(f"βœ“ Found {len(h_alignments)} horizontal alignment groups")
748
  print(f"βœ“ Found {len(v_alignments)} vertical alignment groups")
749
 
 
750
  size_clusters = self.size_normalizer.cluster_sizes_by_type()
751
  print(f"βœ“ Created size clusters for {len(size_clusters)} element types")
752
 
 
753
  element_to_cluster = {}
754
  element_to_size_category = {}
755
  for label, clusters in size_clusters.items():
 
759
  element_to_cluster[id(elem)] = cluster
760
  element_to_size_category[id(elem)] = category
761
 
 
762
  normalized_elements = []
763
 
764
  for elem in self.elements:
 
765
  cluster = element_to_cluster.get(id(elem), [elem])
766
  size_category = element_to_size_category.get(id(elem), f"{elem.label}_default")
767
 
 
768
  norm_width, norm_height = self.size_normalizer.get_normalized_size(elem, cluster)
769
 
 
770
  center_x, center_y = elem.center_x, elem.center_y
771
  norm_bbox = [
772
  center_x - norm_width / 2,
 
775
  center_y + norm_height / 2
776
  ]
777
 
 
778
  snapped_bbox = self.grid.snap_to_grid(norm_bbox, elem.label, preserve_size=True)
779
  grid_position = self.grid.get_grid_position(snapped_bbox)
780
 
 
785
  size_category=size_category
786
  ))
787
 
 
788
  normalized_elements = self._apply_alignment_corrections(
789
  normalized_elements, h_alignments, v_alignments, edge_alignments
790
  )
791
 
 
792
  overlap_resolver = OverlapResolver(self.elements, self.img_width, self.img_height)
793
  normalized_elements = overlap_resolver.resolve_overlaps(normalized_elements)
794
 
 
801
  edge_alignments: Dict) -> List[NormalizedElement]:
802
  """Apply alignment corrections to normalized elements."""
803
 
 
804
  elem_to_normalized = {id(ne.original): ne for ne in normalized_elements}
805
 
 
806
  for h_group in h_alignments:
807
  norm_group = [elem_to_normalized[id(e)] for e in h_group if id(e) in elem_to_normalized]
808
  if len(norm_group) > 1:
 
809
  avg_y = sum((ne.normalized_bbox[1] + ne.normalized_bbox[3]) / 2 for ne in norm_group) / len(norm_group)
810
  for ne in norm_group:
811
  height = ne.normalized_bbox[3] - ne.normalized_bbox[1]
812
  ne.normalized_bbox[1] = avg_y - height / 2
813
  ne.normalized_bbox[3] = avg_y + height / 2
814
 
 
815
  for v_group in v_alignments:
816
  norm_group = [elem_to_normalized[id(e)] for e in v_group if id(e) in elem_to_normalized]
817
  if len(norm_group) > 1:
 
818
  avg_x = sum((ne.normalized_bbox[0] + ne.normalized_bbox[2]) / 2 for ne in norm_group) / len(norm_group)
819
  for ne in norm_group:
820
  width = ne.normalized_bbox[2] - ne.normalized_bbox[0]
821
  ne.normalized_bbox[0] = avg_x - width / 2
822
  ne.normalized_bbox[2] = avg_x + width / 2
823
 
 
824
  for edge_type, groups in edge_alignments.items():
825
  for edge_group in groups:
826
  norm_group = [elem_to_normalized[id(e)] for e in edge_group if id(e) in elem_to_normalized]
 
854
 
855
 
856
  # ============================================================================
857
+ # VISUALIZATION & EXPORT (unchanged)
858
  # ============================================================================
859
  def visualize_comparison(pil_img: Image.Image, elements: List[Element],
860
  normalized_elements: List[NormalizedElement],
 
863
 
864
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 12))
865
 
 
866
  ax1.imshow(pil_img)
867
  ax1.set_title("Original Predictions", fontsize=16, weight='bold')
868
  ax1.axis('off')
 
877
  ax1.text(x1, y1 - 5, elem.label, color='red', fontsize=8,
878
  bbox=dict(facecolor='white', alpha=0.7))
879
 
 
880
  ax2.imshow(pil_img)
881
  ax2.set_title("Normalized & Aligned Layout", fontsize=16, weight='bold')
882
  ax2.axis('off')
883
 
 
884
  for x in range(grid_system.num_columns + 1):
885
  x_pos = x * grid_system.cell_width
886
  ax2.axvline(x=x_pos, color='blue', linestyle=':', linewidth=0.5, alpha=0.3)
 
888
  y_pos = y * grid_system.cell_height
889
  ax2.axhline(y=y_pos, color='blue', linestyle=':', linewidth=0.5, alpha=0.3)
890
 
 
891
  np.random.seed(42)
892
  colors = plt.cm.Set3(np.linspace(0, 1, len(CLASS_NAMES)))
893
  color_map = {name: colors[i] for i, name in enumerate(CLASS_NAMES)}
 
896
  x1, y1, x2, y2 = norm_elem.normalized_bbox
897
  color = color_map[norm_elem.original.label]
898
 
 
899
  rect = patches.Rectangle(
900
  (x1, y1), x2 - x1, y2 - y1,
901
  linewidth=3, edgecolor=color, facecolor='none'
902
  )
903
  ax2.add_patch(rect)
904
 
 
905
  ox1, oy1, ox2, oy2 = norm_elem.original.bbox
906
  orig_rect = patches.Rectangle(
907
  (ox1, oy1), ox2 - ox1, oy2 - oy1,
 
910
  )
911
  ax2.add_patch(orig_rect)
912
 
 
913
  grid_pos = norm_elem.grid_position
914
  label_text = f"{norm_elem.original.label}\n{norm_elem.size_category}\nR{grid_pos['start_row']} C{grid_pos['start_col']}"
915
  ax2.text(x1 + 5, y1 + 15, label_text, color='white', fontsize=7,
 
1039
  text-transform: uppercase;
1040
  }}
1041
 
 
1042
  .button {{
1043
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
1044
  color: white;
 
1112
  </head>
1113
  <body>
1114
  <div class="info-panel">
1115
+ <h3>πŸ“ Layout Info</h3>
1116
  <p><strong>Grid:</strong> {grid_cols} Γ— {grid_rows}</p>
1117
  <p><strong>Elements:</strong> {total_elements}</p>
1118
  <p><strong>Dimensions:</strong> {img_width}px Γ— {img_height}px</p>
 
1160
 
1161
 
1162
  # ============================================================================
1163
+ # MAIN PIPELINE - MODIFIED FOR ONNX
1164
  # ============================================================================
1165
  def process_wireframe(image_path: str,
1166
  save_json: bool = True,
 
1182
  print("=== PROCESS_WIREFRAME START ===")
1183
  print("Input image path:", image_path)
1184
  print("File exists:", os.path.exists(image_path))
1185
+ if os.path.exists(image_path):
1186
+ print("File size:", os.path.getsize(image_path))
1187
 
1188
  print("=" * 80)
1189
+ print("πŸš€ WIREFRAME LAYOUT NORMALIZER (ONNX)")
1190
  print("=" * 80)
1191
 
1192
+ # Step 1: Load ONNX model and get predictions
1193
+ global ort_session
1194
+ if ort_session is None:
1195
+ print("\nπŸ“¦ Loading ONNX model...")
1196
+ print("Model path:", MODEL_PATH)
1197
+ print("Model path exists?", os.path.exists(MODEL_PATH))
 
 
 
1198
  try:
1199
+ ort_session = ort.InferenceSession(MODEL_PATH)
1200
+ print("βœ… ONNX model loaded successfully!")
1201
+ print(f"Input name: {ort_session.get_inputs()[0].name}")
1202
+ print(f"Input shape: {ort_session.get_inputs()[0].shape}")
1203
+ print(f"Output name: {ort_session.get_outputs()[0].name}")
1204
+ print(f"Output shape: {ort_session.get_outputs()[0].shape}")
1205
  except Exception as e:
1206
+ print(f"❌ Error loading ONNX model: {e}")
1207
+ return {}
 
 
 
 
 
 
1208
 
1209
  print(f"\nπŸ“Έ Processing image: {image_path}")
1210
  print("Running detection inference…")
1211
+ try:
1212
+ pil_img, elements = get_predictions(image_path)
1213
+ print(f"βœ… Detected {len(elements)} elements")
1214
+ for elem in elements:
1215
+ print(f" - {elem.label} (conf: {elem.score:.3f}) at {elem.bbox}")
1216
+ except Exception as e:
1217
+ print(f"❌ Error during prediction: {e}")
1218
+ return {}
1219
 
1220
  if not elements:
1221
+ print("⚠️ No elements detected.")
 
1222
  print("β†’ Check thresholds:")
1223
+ print(f" CONF_THRESHOLD: {CONF_THRESHOLD}")
1224
+ print(f" IOU_THRESHOLD: {IOU_THRESHOLD}")
1225
  return {}
1226
 
1227
  # Step 2: Normalize layout
 
1261
  print("πŸ“Š PROCESSING SUMMARY")
1262
  print("=" * 80)
1263
 
 
1264
  type_counts = {}
1265
  for elem in elements:
1266
  type_counts[elem.label] = type_counts.get(elem.label, 0) + 1
 
1269
  for elem_type, count in sorted(type_counts.items()):
1270
  print(f" β€’ {elem_type}: {count}")
1271
 
 
1272
  size_categories = {}
1273
  for norm_elem in normalized_elements:
1274
  size_categories[norm_elem.size_category] = size_categories.get(norm_elem.size_category, 0) + 1
1275
 
1276
+ print(f"\nπŸ“ Size Categories: {len(size_categories)}")
1277
 
 
1278
  h_alignments = normalizer.alignment_detector.detect_horizontal_alignments()
1279
  v_alignments = normalizer.alignment_detector.detect_vertical_alignments()
1280
 
1281
+ print(f"\nπŸ“ Alignment:")
1282
  print(f" β€’ Horizontal groups: {len(h_alignments)}")
1283
  print(f" β€’ Vertical groups: {len(v_alignments)}")
1284
 
 
1327
  'error': str(e)
1328
  })
1329
 
 
1330
  successful = sum(1 for r in all_results if r['success'])
1331
  print(f"\n{'=' * 80}")
1332
  print(f"πŸ“Š BATCH PROCESSING COMPLETE")
wireframe_detection_model_best_700.keras β†’ wireframe.onnx RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5dbcc2a9f3222325ee66087ea94a3ac5b6c674844b5173aae30a9f4bf4290f63
3
- size 53257515
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2c47e1f0f63b4a29dd146331c582860e5981ea0546119b79511a167e856a6277
3
+ size 17701338