import onnxruntime as ort
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image, ImageOps
import json
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional

# ============================================================================
# CONFIGURATION - UPDATED FOR ONNX
# ============================================================================
MODEL_PATH = "./wireframe_detection_model_best_700.onnx"  # Changed to .onnx
OUTPUT_DIR = "./output/"
CLASS_NAMES = ["button", "checkbox", "image", "navbar", "paragraph", "text", "textfield"]
IMG_SIZE = 416
CONF_THRESHOLD = 0.1
IOU_THRESHOLD = 0.1

# Layout Configuration
GRID_COLUMNS = 24
ALIGNMENT_THRESHOLD = 10
SIZE_CLUSTERING_THRESHOLD = 15

# Standard sizes for each element type (relative units)
STANDARD_SIZES = {
    'button': {'width': 2, 'height': 1},
    'checkbox': {'width': 1, 'height': 1},
    'textfield': {'width': 5, 'height': 1},
    'text': {'width': 3, 'height': 1},
    'paragraph': {'width': 8, 'height': 2},
    'image': {'width': 4, 'height': 4},
    'navbar': {'width': 24, 'height': 1}
}

ort_session = None  # Changed from model to ort_session


# ============================================================================
# UTILITY FUNCTIONS FOR ONNX
# ============================================================================
def sigmoid(x):
    """Sigmoid activation function.

    The input is clipped to [-500, 500] before exponentiation so that
    ``np.exp`` cannot overflow on extreme logits.
    """
    return 1 / (1 + np.exp(-np.clip(x, -500, 500)))


def softmax(x, axis=-1):
    """Softmax activation function.

    The maximum along ``axis`` is subtracted before exponentiation for
    numerical stability; this does not change the result.
    """
    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)


def non_max_suppression_numpy(boxes, scores, iou_threshold=0.5, score_threshold=0.1):
    """
    Pure NumPy implementation of Non-Maximum Suppression.

    Args:
        boxes: Array-like of shape (N, 4) with format [x1, y1, x2, y2]
        scores: Array-like of shape (N,) with confidence scores
        iou_threshold: IoU threshold for suppression; a box is dropped when
            its IoU with an already selected box is greater than this value
        score_threshold: Minimum score threshold

    Returns:
        List of integer indices into the ORIGINAL ``boxes`` / ``scores``
        inputs, ordered by descending score.
    """
    boxes = np.asarray(boxes, dtype=np.float64).reshape(-1, 4)
    scores = np.asarray(scores, dtype=np.float64).reshape(-1)
    if boxes.shape[0] == 0:
        return []

    # Remember where each surviving candidate came from. Without this map the
    # indices returned below would refer to the score-filtered arrays and
    # silently point at the wrong boxes in the caller's data.
    candidate_idx = np.flatnonzero(scores >= score_threshold)
    if candidate_idx.size == 0:
        return []
    boxes = boxes[candidate_idx]
    scores = scores[candidate_idx]

    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    areas = (x2 - x1) * (y2 - y1)

    # Candidate positions sorted by descending score.
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        # Pick the remaining box with the highest score.
        best = order[0]
        keep.append(int(candidate_idx[best]))

        rest = order[1:]
        if rest.size == 0:
            break

        # Intersection rectangle between the picked box and every other one.
        xx1 = np.maximum(x1[best], x1[rest])
        yy1 = np.maximum(y1[best], y1[rest])
        xx2 = np.minimum(x2[best], x2[rest])
        yy2 = np.minimum(y2[best], y2[rest])

        w = np.maximum(0.0, xx2 - xx1)
        h = np.maximum(0.0, yy2 - yy1)
        intersection = w * h
        union = areas[best] + areas[rest] - intersection

        # Degenerate (zero-area) boxes give a zero union; treat their IoU as
        # 0 instead of producing NaN and a divide warning.
        iou = np.divide(
            intersection,
            union,
            out=np.zeros_like(intersection),
            where=union > 0,
        )

        # Keep only the boxes that do not overlap the picked one too much.
        order = rest[iou <= iou_threshold]

    return keep


# ============================================================================
# DATA STRUCTURES (unchanged)
# ============================================================================
@dataclass
class Element:
    """Represents a detected UI element.

    ``width``, ``height``, ``center_x`` and ``center_y`` are always recomputed
    from ``bbox`` in ``__post_init__``; any values passed to the constructor
    for them are overwritten.
    """
    label: str
    score: float
    bbox: List[float]  # [x1, y1, x2, y2]
    width: float = 0
    height: float = 0
    center_x: float = 0
    center_y: float = 0

    def __post_init__(self):
        self.width = self.bbox[2] - self.bbox[0]
        self.height = self.bbox[3] - self.bbox[1]
        self.center_x = (self.bbox[0] + self.bbox[2]) / 2
        self.center_y = (self.bbox[1] + self.bbox[3]) / 2


@dataclass
class NormalizedElement:
    """Represents a normalized UI element."""
    original: Element  # the detection this entry was derived from
    normalized_bbox: List[float]  # [x1, y1, x2, y2]
    grid_position: Dict
    size_category: str
    alignment_group: Optional[int] = None
# ============================================================================
# PREDICTION EXTRACTION - MODIFIED FOR ONNX
# ============================================================================
def get_predictions(image_path: str) -> Tuple[Image.Image, List[Element]]:
    """Extract predictions from the ONNX model.

    Args:
        image_path: Path of the image file to run detection on.

    Returns:
        A tuple ``(pil_img, elements)``: the RGB, EXIF-transposed image at its
        original resolution, and the detections that survive per-class NMS
        with bounding boxes in pixels of that image.

    Raises:
        ValueError: If the module-level ``ort_session`` has not been set.
    """
    if ort_session is None:
        raise ValueError("ONNX model not loaded. Please load the model first.")

    # Load and preprocess image. The context manager releases the file handle
    # even if decoding fails; the converted/transposed image is an in-memory
    # copy, so it stays usable after the file is closed.
    with Image.open(image_path) as raw_img:
        pil_img = ImageOps.exif_transpose(raw_img.convert("RGB"))
    orig_w, orig_h = pil_img.size

    resized_img = pil_img.resize((IMG_SIZE, IMG_SIZE), Image.LANCZOS)
    img_array = np.array(resized_img, dtype=np.float32) / 255.0
    input_tensor = np.expand_dims(img_array, axis=0)  # (1, IMG_SIZE, IMG_SIZE, 3)

    # Get predictions from ONNX model (first input / first output only).
    input_name = ort_session.get_inputs()[0].name
    output_name = ort_session.get_outputs()[0].name
    # NOTE(review): assumes the output is (1, S, S, 5 + num_classes) - confirm
    # against the exported model.
    pred_grid = ort_session.run([output_name], {input_name: input_tensor})[0][0]

    raw_boxes = []
    S = pred_grid.shape[0]
    cell_size = 1.0 / S

    for row in range(S):
        for col in range(S):
            cell = pred_grid[row, col]

            # Channel 0: objectness logit.
            obj_score = float(sigmoid(cell[0]))
            if obj_score < CONF_THRESHOLD:
                continue

            # Channels 1-4: centre offset inside the cell and box size, all
            # squashed to (0, 1) and later scaled by the image dimensions.
            x_offset = float(sigmoid(cell[1]))
            y_offset = float(sigmoid(cell[2]))
            width = float(sigmoid(cell[3]))
            height = float(sigmoid(cell[4]))

            # Remaining channels: class logits.
            class_probs = softmax(cell[5:])
            class_id = int(np.argmax(class_probs))
            class_conf = float(class_probs[class_id])

            final_score = obj_score * class_conf
            if final_score < CONF_THRESHOLD:
                continue

            center_x = (col + x_offset) * cell_size
            center_y = (row + y_offset) * cell_size

            x1 = (center_x - width / 2) * orig_w
            y1 = (center_y - height / 2) * orig_h
            x2 = (center_x + width / 2) * orig_w
            y2 = (center_y + height / 2) * orig_h

            if x2 > x1 and y2 > y1:
                raw_boxes.append((class_id, final_score, x1, y1, x2, y2))

    # Group candidates by class once instead of rescanning every box for
    # every class.
    boxes_by_class = defaultdict(list)
    for class_id, score, x1, y1, x2, y2 in raw_boxes:
        boxes_by_class[class_id].append((score, x1, y1, x2, y2))

    # Apply NMS per class using NumPy implementation. Iterating CLASS_NAMES
    # keeps the output ordered by class and ignores ids without a name.
    elements = []
    for class_id, class_name in enumerate(CLASS_NAMES):
        class_boxes = boxes_by_class.get(class_id)
        if not class_boxes:
            continue

        scores = np.array([b[0] for b in class_boxes])
        boxes_xyxy = np.array([[b[1], b[2], b[3], b[4]] for b in class_boxes])

        selected_indices = non_max_suppression_numpy(
            boxes=boxes_xyxy,
            scores=scores,
            iou_threshold=IOU_THRESHOLD,
            score_threshold=CONF_THRESHOLD
        )

        for idx in selected_indices:
            score, x1, y1, x2, y2 = class_boxes[idx]
            elements.append(Element(
                label=class_name,
                score=float(score),
                bbox=[float(x1), float(y1), float(x2), float(y2)]
            ))

    return pil_img, elements


# ============================================================================
# ALIGNMENT DETECTION (unchanged)
# ============================================================================
class AlignmentDetector:
    """Detects alignment relationships between elements."""

    def __init__(self, elements: List[Element], threshold: float = ALIGNMENT_THRESHOLD):
        self.elements = elements
        self.threshold = threshold

    def _group_by_center(self, cluster_attr: str, order_attr: str) -> List[List[Element]]:
        """Group elements whose ``cluster_attr`` lies within the threshold.

        Elements are visited in ascending ``cluster_attr`` order and compared
        with the running mean of the current group. Only groups with more
        than one member are returned, each sorted by ``order_attr``.
        """
        if not self.elements:
            return []

        ordered = sorted(self.elements, key=lambda e: getattr(e, cluster_attr))
        groups = []
        current_group = [ordered[0]]

        for elem in ordered[1:]:
            mean_pos = sum(getattr(e, cluster_attr) for e in current_group) / len(current_group)
            if abs(getattr(elem, cluster_attr) - mean_pos) <= self.threshold:
                current_group.append(elem)
                continue
            if len(current_group) > 1:
                groups.append(sorted(current_group, key=lambda e: getattr(e, order_attr)))
            current_group = [elem]

        if len(current_group) > 1:
            groups.append(sorted(current_group, key=lambda e: getattr(e, order_attr)))

        return groups

    def detect_horizontal_alignments(self) -> List[List[Element]]:
        """Group elements that are horizontally aligned (same Y position)."""
        return self._group_by_center('center_y', 'center_x')

    def detect_vertical_alignments(self) -> List[List[Element]]:
        """Group elements that are vertically aligned (same X position)."""
        return self._group_by_center('center_x', 'center_y')

    def detect_edge_alignments(self) -> Dict[str, List[List[Element]]]:
        """Detect elements with aligned edges (left, right, top, bottom).

        Returns:
            A dict with keys ``'left'``, ``'right'``, ``'top'`` and
            ``'bottom'``; each value is a list of groups of elements whose
            corresponding bbox edge falls within the threshold.
        """
        alignments = {
            'left': [],
            'right': [],
            'top': [],
            'bottom': []
        }

        if not self.elements:
            return alignments

        # Edge name -> index of that edge inside bbox [x1, y1, x2, y2].
        edge_to_bbox_index = {'left': 0, 'right': 2, 'top': 1, 'bottom': 3}
        for edge, bbox_index in edge_to_bbox_index.items():
            # Bind the index as a default so each lambda keeps its own value.
            edge_value = lambda e, i=bbox_index: e.bbox[i]
            ordered = sorted(self.elements, key=edge_value)
            alignments[edge] = self._cluster_by_value(ordered, edge_value)

        return alignments

    def _cluster_by_value(self, elements: List[Element], value_func) -> List[List[Element]]:
        """Cluster elements by a value function within threshold.

        ``elements`` is walked in the order given (callers pass it already
        sorted by ``value_func``). Each element is compared with the running
        mean of the current cluster; clusters with a single member are
        discarded.
        """
        if not elements:
            return []

        groups = []
        current_group = [elements[0]]
        current_value = value_func(elements[0])

        for elem in elements[1:]:
            elem_value = value_func(elem)
            if abs(elem_value - current_value) <= self.threshold:
                current_group.append(elem)
                # Incremental update of the cluster mean.
                current_value = (current_value * (len(current_group) - 1) + elem_value) / len(current_group)
            else:
                if len(current_group) > 1:
                    groups.append(current_group)
                current_group = [elem]
                current_value = elem_value

        if len(current_group) > 1:
            groups.append(current_group)

        return groups
============================================================================ # SIZE NORMALIZATION (unchanged) # ============================================================================ class SizeNormalizer: """Normalizes element sizes based on type and clustering.""" def __init__(self, elements: List[Element], img_width: float, img_height: float): self.elements = elements self.img_width = img_width self.img_height = img_height self.size_clusters = {} def cluster_sizes_by_type(self) -> Dict[str, List[List[Element]]]: """Cluster elements of same type by similar sizes.""" clusters_by_type = {} for label in CLASS_NAMES: type_elements = [e for e in self.elements if e.label == label] if not type_elements: continue width_clusters = self._cluster_by_dimension(type_elements, 'width') final_clusters = [] for width_cluster in width_clusters: height_clusters = self._cluster_by_dimension(width_cluster, 'height') final_clusters.extend(height_clusters) clusters_by_type[label] = final_clusters return clusters_by_type def _cluster_by_dimension(self, elements: List[Element], dimension: str) -> List[List[Element]]: """Cluster elements by width or height.""" if not elements: return [] sorted_elements = sorted(elements, key=lambda e: getattr(e, dimension)) clusters = [] current_cluster = [sorted_elements[0]] for elem in sorted_elements[1:]: avg_dim = sum(getattr(e, dimension) for e in current_cluster) / len(current_cluster) if abs(getattr(elem, dimension) - avg_dim) <= SIZE_CLUSTERING_THRESHOLD: current_cluster.append(elem) else: clusters.append(current_cluster) current_cluster = [elem] clusters.append(current_cluster) return clusters def get_normalized_size(self, element: Element, size_cluster: List[Element]) -> Tuple[float, float]: """Get normalized size for an element based on its cluster.""" if len(size_cluster) >= 3: widths = sorted([e.width for e in size_cluster]) heights = sorted([e.height for e in size_cluster]) median_width = widths[len(widths) // 2] median_height = 
heights[len(heights) // 2] if abs(element.width - median_width) / median_width < 0.3: normalized_width = round(median_width) else: normalized_width = round(element.width) if abs(element.height - median_height) / median_height < 0.3: normalized_height = round(median_height) else: normalized_height = round(element.height) else: normalized_width = round(element.width) normalized_height = round(element.height) return normalized_width, normalized_height # ============================================================================ # GRID-BASED LAYOUT SYSTEM (unchanged) # ============================================================================ class GridLayoutSystem: """Grid-based layout system for precise positioning.""" def __init__(self, img_width: float, img_height: float, num_columns: int = GRID_COLUMNS): self.img_width = img_width self.img_height = img_height self.num_columns = num_columns cell_width = img_width / num_columns self.num_rows = max(1, int(img_height / cell_width)) self.cell_width = img_width / num_columns self.cell_height = img_height / self.num_rows print(f"π Grid System: {self.num_columns} columns Γ {self.num_rows} rows") print(f"π Cell size: {self.cell_width:.1f}px Γ {self.cell_height:.1f}px") def snap_to_grid(self, bbox: List[float], element_label: str, preserve_size: bool = True) -> List[float]: """Snap bounding box to grid.""" x1, y1, x2, y2 = bbox original_width = x2 - x1 original_height = y2 - y1 center_x = (x1 + x2) / 2 center_y = (y1 + y2) / 2 center_col = round(center_x / self.cell_width) center_row = round(center_y / self.cell_height) if preserve_size: width_cells = max(1, round(original_width / self.cell_width)) height_cells = max(1, round(original_height / self.cell_height)) else: standard = STANDARD_SIZES.get(element_label, {'width': 2, 'height': 1}) width_cells = max(1, round(original_width / self.cell_width)) height_cells = max(1, round(original_height / self.cell_height)) if abs(width_cells - standard['width']) <= 0.5: 
width_cells = standard['width'] if abs(height_cells - standard['height']) <= 0.5: height_cells = standard['height'] start_col = center_col - width_cells // 2 start_row = center_row - height_cells // 2 start_col = max(0, min(start_col, self.num_columns - width_cells)) start_row = max(0, min(start_row, self.num_rows - height_cells)) snapped_x1 = start_col * self.cell_width snapped_y1 = start_row * self.cell_height snapped_x2 = (start_col + width_cells) * self.cell_width snapped_y2 = (start_row + height_cells) * self.cell_height return [snapped_x1, snapped_y1, snapped_x2, snapped_y2] def get_grid_position(self, bbox: List[float]) -> Dict: """Get grid position information for a bounding box.""" x1, y1, x2, y2 = bbox start_col = int(x1 / self.cell_width) start_row = int(y1 / self.cell_height) end_col = int(np.ceil(x2 / self.cell_width)) end_row = int(np.ceil(y2 / self.cell_height)) return { 'start_row': start_row, 'end_row': end_row, 'start_col': start_col, 'end_col': end_col, 'rowspan': end_row - start_row, 'colspan': end_col - start_col } # ============================================================================ # OVERLAP DETECTION & RESOLUTION (unchanged) # ============================================================================ class OverlapResolver: """Detects and resolves overlapping elements.""" def __init__(self, elements: List[Element], img_width: float, img_height: float): self.elements = elements self.img_width = img_width self.img_height = img_height self.overlap_threshold = 0.2 def compute_iou(self, bbox1: List[float], bbox2: List[float]) -> float: """Compute Intersection over Union between two bounding boxes.""" x1 = max(bbox1[0], bbox2[0]) y1 = max(bbox1[1], bbox2[1]) x2 = min(bbox1[2], bbox2[2]) y2 = min(bbox1[3], bbox2[3]) if x2 <= x1 or y2 <= y1: return 0.0 intersection = (x2 - x1) * (y2 - y1) area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) union = area1 + area2 - intersection return 
intersection / union if union > 0 else 0.0 def compute_overlap_ratio(self, bbox1: List[float], bbox2: List[float]) -> Tuple[float, float]: """Compute what percentage of each box overlaps with the other.""" x1 = max(bbox1[0], bbox2[0]) y1 = max(bbox1[1], bbox2[1]) x2 = min(bbox1[2], bbox2[2]) y2 = min(bbox1[3], bbox2[3]) if x2 <= x1 or y2 <= y1: return 0.0, 0.0 intersection = (x2 - x1) * (y2 - y1) area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) overlap_ratio1 = intersection / area1 if area1 > 0 else 0.0 overlap_ratio2 = intersection / area2 if area2 > 0 else 0.0 return overlap_ratio1, overlap_ratio2 def resolve_overlaps(self, normalized_elements: List[NormalizedElement]) -> List[NormalizedElement]: """Resolve overlaps by adjusting element positions.""" print("\nπ Checking for overlaps...") overlaps = [] for i in range(len(normalized_elements)): for j in range(i + 1, len(normalized_elements)): ne1 = normalized_elements[i] ne2 = normalized_elements[j] iou = self.compute_iou(ne1.normalized_bbox, ne2.normalized_bbox) if iou > 0: overlap1, overlap2 = self.compute_overlap_ratio( ne1.normalized_bbox, ne2.normalized_bbox ) max_overlap = max(overlap1, overlap2) if max_overlap >= self.overlap_threshold: overlaps.append({ 'idx1': i, 'idx2': j, 'elem1': ne1, 'elem2': ne2, 'overlap': max_overlap, 'overlap1': overlap1, 'overlap2': overlap2, 'iou': iou }) if not overlaps: print("β No significant overlaps detected") return normalized_elements print(f"β οΈ Found {len(overlaps)} overlapping element pairs") overlaps.sort(key=lambda x: x['overlap'], reverse=True) elements_to_remove = set() for overlap_info in overlaps: idx1 = overlap_info['idx1'] idx2 = overlap_info['idx2'] if idx1 in elements_to_remove or idx2 in elements_to_remove: continue elem1 = overlap_info['elem1'] elem2 = overlap_info['elem2'] overlap_ratio = overlap_info['overlap'] if overlap_ratio > 0.7: if elem1.original.score < elem2.original.score: 
elements_to_remove.add(idx1) print(f" ποΈ Removing {elem1.original.label} (conf: {elem1.original.score:.2f}) - " f"overlaps {overlap_ratio * 100:.1f}% with {elem2.original.label}") else: elements_to_remove.add(idx2) print(f" ποΈ Removing {elem2.original.label} (conf: {elem2.original.score:.2f}) - " f"overlaps {overlap_ratio * 100:.1f}% with {elem1.original.label}") elif overlap_ratio > 0.4: self._try_separate_elements(elem1, elem2, overlap_info) print(f" βοΈ Separating {elem1.original.label} and {elem2.original.label} " f"(overlap: {overlap_ratio * 100:.1f}%)") else: self._shrink_overlapping_edges(elem1, elem2, overlap_info) print(f" π Shrinking {elem1.original.label} and {elem2.original.label} " f"(overlap: {overlap_ratio * 100:.1f}%)") if elements_to_remove: normalized_elements = [ ne for i, ne in enumerate(normalized_elements) if i not in elements_to_remove ] print(f"β Removed {len(elements_to_remove)} completely overlapping elements") return normalized_elements def _try_separate_elements(self, elem1: NormalizedElement, elem2: NormalizedElement, overlap_info: Dict): """Try to separate two significantly overlapping elements.""" bbox1 = elem1.normalized_bbox bbox2 = elem2.normalized_bbox overlap_x1 = max(bbox1[0], bbox2[0]) overlap_y1 = max(bbox1[1], bbox2[1]) overlap_x2 = min(bbox1[2], bbox2[2]) overlap_y2 = min(bbox1[3], bbox2[3]) overlap_width = overlap_x2 - overlap_x1 overlap_height = overlap_y2 - overlap_y1 center1_x = (bbox1[0] + bbox1[2]) / 2 center1_y = (bbox1[1] + bbox1[3]) / 2 center2_x = (bbox2[0] + bbox2[2]) / 2 center2_y = (bbox2[1] + bbox2[3]) / 2 dx = abs(center2_x - center1_x) dy = abs(center2_y - center1_y) min_gap = 3 if dx > dy: if center1_x < center2_x: midpoint = (bbox1[2] + bbox2[0]) / 2 bbox1[2] = midpoint - min_gap bbox2[0] = midpoint + min_gap else: midpoint = (bbox2[2] + bbox1[0]) / 2 bbox2[2] = midpoint - min_gap bbox1[0] = midpoint + min_gap else: if center1_y < center2_y: midpoint = (bbox1[3] + bbox2[1]) / 2 bbox1[3] = midpoint - 
min_gap bbox2[1] = midpoint + min_gap else: midpoint = (bbox2[3] + bbox1[1]) / 2 bbox2[3] = midpoint - min_gap bbox1[1] = midpoint + min_gap self._ensure_valid_bbox(bbox1) self._ensure_valid_bbox(bbox2) def _shrink_overlapping_edges(self, elem1: NormalizedElement, elem2: NormalizedElement, overlap_info: Dict): """Shrink overlapping edges for moderate overlaps.""" bbox1 = elem1.normalized_bbox bbox2 = elem2.normalized_bbox overlap_x1 = max(bbox1[0], bbox2[0]) overlap_y1 = max(bbox1[1], bbox2[1]) overlap_x2 = min(bbox1[2], bbox2[2]) overlap_y2 = min(bbox1[3], bbox2[3]) overlap_width = overlap_x2 - overlap_x1 overlap_height = overlap_y2 - overlap_y1 gap = 2 if overlap_width > overlap_height: shrink = overlap_width / 2 + gap if bbox1[0] < bbox2[0]: bbox1[2] -= shrink bbox2[0] += shrink else: bbox2[2] -= shrink bbox1[0] += shrink else: shrink = overlap_height / 2 + gap if bbox1[1] < bbox2[1]: bbox1[3] -= shrink bbox2[1] += shrink else: bbox2[3] -= shrink bbox1[1] += shrink self._ensure_valid_bbox(bbox1) self._ensure_valid_bbox(bbox2) def _ensure_valid_bbox(self, bbox: List[float]): """Ensure bounding box has minimum size and is within image bounds.""" min_size = 8 if bbox[2] - bbox[0] < min_size: center_x = (bbox[0] + bbox[2]) / 2 bbox[0] = center_x - min_size / 2 bbox[2] = center_x + min_size / 2 if bbox[3] - bbox[1] < min_size: center_y = (bbox[1] + bbox[3]) / 2 bbox[1] = center_y - min_size / 2 bbox[3] = center_y + min_size / 2 bbox[0] = max(0, min(bbox[0], self.img_width)) bbox[1] = max(0, min(bbox[1], self.img_height)) bbox[2] = max(0, min(bbox[2], self.img_width)) bbox[3] = max(0, min(bbox[3], self.img_height)) # ============================================================================ # MAIN NORMALIZATION ENGINE (unchanged) # ============================================================================ class LayoutNormalizer: """Main engine for normalizing wireframe layout.""" def __init__(self, elements: List[Element], img_width: float, img_height: float): 
self.elements = elements self.img_width = img_width self.img_height = img_height self.grid = GridLayoutSystem(img_width, img_height) self.alignment_detector = AlignmentDetector(elements) self.size_normalizer = SizeNormalizer(elements, img_width, img_height) def normalize_layout(self) -> List[NormalizedElement]: """Normalize all elements with proper sizing and alignment.""" print("\nπ§ Starting layout normalization...") h_alignments = self.alignment_detector.detect_horizontal_alignments() v_alignments = self.alignment_detector.detect_vertical_alignments() edge_alignments = self.alignment_detector.detect_edge_alignments() print(f"β Found {len(h_alignments)} horizontal alignment groups") print(f"β Found {len(v_alignments)} vertical alignment groups") size_clusters = self.size_normalizer.cluster_sizes_by_type() print(f"β Created size clusters for {len(size_clusters)} element types") element_to_cluster = {} element_to_size_category = {} for label, clusters in size_clusters.items(): for i, cluster in enumerate(clusters): category = f"{label}_size_{i + 1}" for elem in cluster: element_to_cluster[id(elem)] = cluster element_to_size_category[id(elem)] = category normalized_elements = [] for elem in self.elements: cluster = element_to_cluster.get(id(elem), [elem]) size_category = element_to_size_category.get(id(elem), f"{elem.label}_default") norm_width, norm_height = self.size_normalizer.get_normalized_size(elem, cluster) center_x, center_y = elem.center_x, elem.center_y norm_bbox = [ center_x - norm_width / 2, center_y - norm_height / 2, center_x + norm_width / 2, center_y + norm_height / 2 ] snapped_bbox = self.grid.snap_to_grid(norm_bbox, elem.label, preserve_size=True) grid_position = self.grid.get_grid_position(snapped_bbox) normalized_elements.append(NormalizedElement( original=elem, normalized_bbox=snapped_bbox, grid_position=grid_position, size_category=size_category )) normalized_elements = self._apply_alignment_corrections( normalized_elements, h_alignments, 
v_alignments, edge_alignments ) overlap_resolver = OverlapResolver(self.elements, self.img_width, self.img_height) normalized_elements = overlap_resolver.resolve_overlaps(normalized_elements) print(f"β Normalized {len(normalized_elements)} elements") return normalized_elements def _apply_alignment_corrections(self, normalized_elements: List[NormalizedElement], h_alignments: List[List[Element]], v_alignments: List[List[Element]], edge_alignments: Dict) -> List[NormalizedElement]: """Apply alignment corrections to normalized elements.""" elem_to_normalized = {id(ne.original): ne for ne in normalized_elements} for h_group in h_alignments: norm_group = [elem_to_normalized[id(e)] for e in h_group if id(e) in elem_to_normalized] if len(norm_group) > 1: avg_y = sum((ne.normalized_bbox[1] + ne.normalized_bbox[3]) / 2 for ne in norm_group) / len(norm_group) for ne in norm_group: height = ne.normalized_bbox[3] - ne.normalized_bbox[1] ne.normalized_bbox[1] = avg_y - height / 2 ne.normalized_bbox[3] = avg_y + height / 2 for v_group in v_alignments: norm_group = [elem_to_normalized[id(e)] for e in v_group if id(e) in elem_to_normalized] if len(norm_group) > 1: avg_x = sum((ne.normalized_bbox[0] + ne.normalized_bbox[2]) / 2 for ne in norm_group) / len(norm_group) for ne in norm_group: width = ne.normalized_bbox[2] - ne.normalized_bbox[0] ne.normalized_bbox[0] = avg_x - width / 2 ne.normalized_bbox[2] = avg_x + width / 2 for edge_type, groups in edge_alignments.items(): for edge_group in groups: norm_group = [elem_to_normalized[id(e)] for e in edge_group if id(e) in elem_to_normalized] if len(norm_group) > 1: if edge_type == 'left': avg_left = sum(ne.normalized_bbox[0] for ne in norm_group) / len(norm_group) for ne in norm_group: width = ne.normalized_bbox[2] - ne.normalized_bbox[0] ne.normalized_bbox[0] = avg_left ne.normalized_bbox[2] = avg_left + width elif edge_type == 'right': avg_right = sum(ne.normalized_bbox[2] for ne in norm_group) / len(norm_group) for ne in norm_group: 
width = ne.normalized_bbox[2] - ne.normalized_bbox[0] ne.normalized_bbox[2] = avg_right ne.normalized_bbox[0] = avg_right - width elif edge_type == 'top': avg_top = sum(ne.normalized_bbox[1] for ne in norm_group) / len(norm_group) for ne in norm_group: height = ne.normalized_bbox[3] - ne.normalized_bbox[1] ne.normalized_bbox[1] = avg_top ne.normalized_bbox[3] = avg_top + height elif edge_type == 'bottom': avg_bottom = sum(ne.normalized_bbox[3] for ne in norm_group) / len(norm_group) for ne in norm_group: height = ne.normalized_bbox[3] - ne.normalized_bbox[1] ne.normalized_bbox[3] = avg_bottom ne.normalized_bbox[1] = avg_bottom - height return normalized_elements # ============================================================================ # VISUALIZATION & EXPORT (unchanged) # ============================================================================ def visualize_comparison(pil_img: Image.Image, elements: List[Element], normalized_elements: List[NormalizedElement], grid_system: GridLayoutSystem): """Visualize original vs normalized layout.""" fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 12)) ax1.imshow(pil_img) ax1.set_title("Original Predictions", fontsize=16, weight='bold') ax1.axis('off') for elem in elements: x1, y1, x2, y2 = elem.bbox rect = patches.Rectangle( (x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='red', facecolor='none' ) ax1.add_patch(rect) ax1.text(x1, y1 - 5, elem.label, color='red', fontsize=8, bbox=dict(facecolor='white', alpha=0.7)) ax2.imshow(pil_img) ax2.set_title("Normalized & Aligned Layout", fontsize=16, weight='bold') ax2.axis('off') for x in range(grid_system.num_columns + 1): x_pos = x * grid_system.cell_width ax2.axvline(x=x_pos, color='blue', linestyle=':', linewidth=0.5, alpha=0.3) for y in range(grid_system.num_rows + 1): y_pos = y * grid_system.cell_height ax2.axhline(y=y_pos, color='blue', linestyle=':', linewidth=0.5, alpha=0.3) np.random.seed(42) colors = plt.cm.Set3(np.linspace(0, 1, len(CLASS_NAMES))) color_map = 
{name: colors[i] for i, name in enumerate(CLASS_NAMES)} for norm_elem in normalized_elements: x1, y1, x2, y2 = norm_elem.normalized_bbox color = color_map[norm_elem.original.label] rect = patches.Rectangle( (x1, y1), x2 - x1, y2 - y1, linewidth=3, edgecolor=color, facecolor='none' ) ax2.add_patch(rect) ox1, oy1, ox2, oy2 = norm_elem.original.bbox orig_rect = patches.Rectangle( (ox1, oy1), ox2 - ox1, oy2 - oy1, linewidth=1, edgecolor='gray', facecolor='none', linestyle='--', alpha=0.5 ) ax2.add_patch(orig_rect) grid_pos = norm_elem.grid_position label_text = f"{norm_elem.original.label}\n{norm_elem.size_category}\nR{grid_pos['start_row']} C{grid_pos['start_col']}" ax2.text(x1 + 5, y1 + 15, label_text, color='white', fontsize=7, bbox=dict(facecolor=color, alpha=0.8, pad=2)) plt.tight_layout() plt.show() def export_to_json(normalized_elements: List[NormalizedElement], grid_system: GridLayoutSystem, output_path: str): """Export normalized layout to JSON.""" output = { 'metadata': { 'image_width': grid_system.img_width, 'image_height': grid_system.img_height, 'grid_system': { 'columns': grid_system.num_columns, 'rows': grid_system.num_rows, 'cell_width': round(grid_system.cell_width, 2), 'cell_height': round(grid_system.cell_height, 2) }, 'total_elements': len(normalized_elements) }, 'elements': [] } for i, norm_elem in enumerate(normalized_elements): orig = norm_elem.original norm_bbox = norm_elem.normalized_bbox element_data = { 'id': i, 'type': orig.label, 'confidence': round(orig.score, 3), 'size_category': norm_elem.size_category, 'original_bbox': { 'x1': round(orig.bbox[0], 2), 'y1': round(orig.bbox[1], 2), 'x2': round(orig.bbox[2], 2), 'y2': round(orig.bbox[3], 2), 'width': round(orig.width, 2), 'height': round(orig.height, 2) }, 'normalized_bbox': { 'x1': round(norm_bbox[0], 2), 'y1': round(norm_bbox[1], 2), 'x2': round(norm_bbox[2], 2), 'y2': round(norm_bbox[3], 2), 'width': round(norm_bbox[2] - norm_bbox[0], 2), 'height': round(norm_bbox[3] - norm_bbox[1], 2) 
}, 'grid_position': norm_elem.grid_position, 'percentage': { 'x1': round((norm_bbox[0] / grid_system.img_width) * 100, 2), 'y1': round((norm_bbox[1] / grid_system.img_height) * 100, 2), 'x2': round((norm_bbox[2] / grid_system.img_width) * 100, 2), 'y2': round((norm_bbox[3] / grid_system.img_height) * 100, 2) } } output['elements'].append(element_data) os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else '.', exist_ok=True) with open(output_path, 'w') as f: json.dump(output, f, indent=2) print(f"\nβ Exported normalized layout to: {output_path}") def export_to_html(normalized_elements: List[NormalizedElement], grid_system: GridLayoutSystem, output_path: str): """Export normalized layout as responsive HTML/CSS.""" html_template = """
Grid: {grid_cols} Γ {grid_rows}
Elements: {total_elements}
Dimensions: {img_width}px Γ {img_height}px
Hover over elements to see details