Spaces:
Sleeping
Sleeping
| import onnxruntime as ort | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as patches | |
| from PIL import Image, ImageOps | |
| import json | |
| import os | |
| from collections import defaultdict | |
| from dataclasses import dataclass | |
| from typing import List, Tuple, Dict, Optional | |
| # ============================================================================ | |
| # CONFIGURATION - UPDATED FOR ONNX | |
| # ============================================================================ | |
# Path to the exported ONNX detector and where results are written.
MODEL_PATH = "./wireframe_detection_model_best_700.onnx"
OUTPUT_DIR = "./output/"

# Class index -> label; the order must match the model's class logits.
CLASS_NAMES = ["button", "checkbox", "image", "navbar", "paragraph", "text", "textfield"]

# Square input resolution fed to the model.
IMG_SIZE = 416

# Minimum objectness / combined score, and IoU above which NMS suppresses a box.
CONF_THRESHOLD = 0.1
IOU_THRESHOLD = 0.1

# Layout configuration.
GRID_COLUMNS = 24
# Pixel tolerances (original-image pixels) for alignment and size clustering.
ALIGNMENT_THRESHOLD = 10
SIZE_CLUSTERING_THRESHOLD = 15

# Preferred span of each element type, measured in grid cells.
STANDARD_SIZES = {
    label: {'width': span_w, 'height': span_h}
    for label, (span_w, span_h) in (
        ('button', (2, 1)),
        ('checkbox', (1, 1)),
        ('textfield', (5, 1)),
        ('text', (3, 1)),
        ('paragraph', (8, 2)),
        ('image', (4, 4)),
        ('navbar', (24, 1)),
    )
}

# Inference session used by get_predictions(); must be assigned (presumably an
# onnxruntime.InferenceSession) before predictions are requested.
ort_session = None
| # ============================================================================ | |
| # UTILITY FUNCTIONS FOR ONNX | |
| # ============================================================================ | |
def sigmoid(x):
    """Return the element-wise logistic function ``1 / (1 + exp(-x))``.

    The input is clipped to ``[-500, 500]`` first so that ``np.exp`` cannot
    overflow. Accepts scalars or arrays.
    """
    bounded = np.clip(x, -500, 500)
    denominator = 1 + np.exp(-bounded)
    return 1 / denominator
def softmax(x, axis=-1):
    """Return the softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating so that large
    logits do not overflow; the result along ``axis`` sums to one.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    weights = np.exp(shifted)
    normaliser = np.sum(weights, axis=axis, keepdims=True)
    return weights / normaliser
def non_max_suppression_numpy(boxes, scores, iou_threshold=0.5, score_threshold=0.1):
    """
    Pure NumPy implementation of greedy Non-Maximum Suppression.

    Args:
        boxes: Array-like of shape (N, 4) with format [x1, y1, x2, y2]
        scores: Array-like of shape (N,) with confidence scores
        iou_threshold: A box is suppressed when its IoU with an already kept
            box is greater than this value
        score_threshold: Boxes scoring below this are discarded up front

    Returns:
        List of Python ``int`` indices into the ORIGINAL ``boxes``/``scores``
        arrays, ordered by descending score.
    """
    if len(boxes) == 0:
        return []
    boxes = np.asarray(boxes, dtype=np.float64)
    scores = np.asarray(scores, dtype=np.float64)

    # Positions (in the caller's arrays) of boxes that survive the score
    # filter. All work below happens on the filtered subset, and results are
    # mapped back through this array so returned indices stay valid for the
    # caller's unfiltered inputs.
    candidate_idx = np.flatnonzero(scores >= score_threshold)
    if candidate_idx.size == 0:
        return []

    candidates = boxes[candidate_idx]
    x1 = candidates[:, 0]
    y1 = candidates[:, 1]
    x2 = candidates[:, 2]
    y2 = candidates[:, 3]
    areas = (x2 - x1) * (y2 - y1)

    # Highest score first.
    order = scores[candidate_idx].argsort()[::-1]
    keep = []
    while order.size > 0:
        best = order[0]
        rest = order[1:]
        keep.append(int(candidate_idx[best]))

        xx1 = np.maximum(x1[best], x1[rest])
        yy1 = np.maximum(y1[best], y1[rest])
        xx2 = np.minimum(x2[best], x2[rest])
        yy2 = np.minimum(y2[best], y2[rest])
        intersection = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
        union = areas[best] + areas[rest] - intersection
        # Degenerate (zero-area) pairs would divide by zero and yield NaN;
        # treat their IoU as 0 instead.
        iou = np.divide(intersection, union,
                        out=np.zeros_like(intersection), where=union > 0)

        # Keep only boxes that do not overlap the chosen one too much.
        order = rest[iou <= iou_threshold]
    return keep
| # ============================================================================ | |
| # DATA STRUCTURES (unchanged) | |
| # ============================================================================ | |
@dataclass
class Element:
    """A single UI element detected by the model.

    ``bbox`` is ``[x1, y1, x2, y2]``; get_predictions() fills it with pixel
    coordinates of the original image. ``width``, ``height``, ``center_x`` and
    ``center_y`` are derived from ``bbox`` in ``__post_init__``; any values
    passed for them are overwritten.
    """
    label: str
    score: float
    bbox: List[float]  # [x1, y1, x2, y2]
    width: float = 0
    height: float = 0
    center_x: float = 0
    center_y: float = 0

    def __post_init__(self):
        # Derived geometry is always recomputed from bbox so it stays consistent.
        self.width = self.bbox[2] - self.bbox[0]
        self.height = self.bbox[3] - self.bbox[1]
        self.center_x = (self.bbox[0] + self.bbox[2]) / 2
        self.center_y = (self.bbox[1] + self.bbox[3]) / 2
@dataclass
class NormalizedElement:
    """A detected element together with its normalized (snapped/aligned) geometry.

    Attributes:
        original: The raw detection this entry was derived from.
        normalized_bbox: ``[x1, y1, x2, y2]`` after size normalization, grid
            snapping and alignment; mutated in place by later layout passes.
        grid_position: Dict with ``start_row``, ``end_row``, ``start_col``,
            ``end_col``, ``rowspan`` and ``colspan`` keys as produced by
            ``GridLayoutSystem.get_grid_position``.
        size_category: Label such as ``"button_size_1"`` naming the size
            cluster the element belongs to.
        alignment_group: Optional alignment group id (unset by default).
    """
    original: "Element"
    normalized_bbox: List[float]
    grid_position: Dict
    size_category: str
    alignment_group: Optional[int] = None
| # ============================================================================ | |
| # PREDICTION EXTRACTION - MODIFIED FOR ONNX | |
| # ============================================================================ | |
def _decode_prediction_grid(
    pred_grid: np.ndarray, orig_w: int, orig_h: int
) -> List[Tuple[int, float, float, float, float, float]]:
    """Decode a raw prediction grid into ``(class_id, score, x1, y1, x2, y2)`` candidates.

    Each grid cell is read as ``[objectness, x_offset, y_offset, width, height,
    class_logit_0, class_logit_1, ...]``. Objectness, offsets and sizes go
    through ``sigmoid``; class logits through ``softmax``. Offsets are relative
    to the cell, width/height are fractions of the whole image, and the
    returned corners are scaled to the original image size in pixels. Cells
    whose objectness or combined score is below ``CONF_THRESHOLD`` are skipped,
    as are boxes with non-positive width or height.
    """
    num_rows, num_cols = pred_grid.shape[0], pred_grid.shape[1]
    # Rows and columns get their own cell size so non-square grids decode
    # correctly (identical to a single cell size for square grids).
    cell_w = 1.0 / num_cols
    cell_h = 1.0 / num_rows

    candidates = []
    for row in range(num_rows):
        for col in range(num_cols):
            cell = pred_grid[row, col]
            obj_score = float(sigmoid(cell[0]))
            if obj_score < CONF_THRESHOLD:
                continue

            x_offset = float(sigmoid(cell[1]))
            y_offset = float(sigmoid(cell[2]))
            width = float(sigmoid(cell[3]))
            height = float(sigmoid(cell[4]))

            class_probs = softmax(cell[5:])
            class_id = int(np.argmax(class_probs))
            final_score = obj_score * float(class_probs[class_id])
            if final_score < CONF_THRESHOLD:
                continue

            center_x = (col + x_offset) * cell_w
            center_y = (row + y_offset) * cell_h
            x1 = (center_x - width / 2) * orig_w
            y1 = (center_y - height / 2) * orig_h
            x2 = (center_x + width / 2) * orig_w
            y2 = (center_y + height / 2) * orig_h
            if x2 > x1 and y2 > y1:
                candidates.append((class_id, final_score, x1, y1, x2, y2))
    return candidates


def _apply_class_nms(candidates) -> List[Element]:
    """Run NMS independently per class and wrap the surviving boxes as Elements.

    Candidates whose class id has no entry in ``CLASS_NAMES`` are dropped.
    """
    elements = []
    for class_id, label in enumerate(CLASS_NAMES):
        # (score, x1, y1, x2, y2) tuples for this class only.
        class_boxes = [cand[1:] for cand in candidates if cand[0] == class_id]
        if not class_boxes:
            continue
        scores = np.array([b[0] for b in class_boxes])
        boxes_xyxy = np.array([b[1:] for b in class_boxes])
        selected_indices = non_max_suppression_numpy(
            boxes=boxes_xyxy,
            scores=scores,
            iou_threshold=IOU_THRESHOLD,
            score_threshold=CONF_THRESHOLD
        )
        for idx in selected_indices:
            score, x1, y1, x2, y2 = class_boxes[idx]
            elements.append(Element(
                label=label,
                score=float(score),
                bbox=[float(x1), float(y1), float(x2), float(y2)]
            ))
    return elements


def get_predictions(image_path: str) -> Tuple[Image.Image, List[Element]]:
    """Run the ONNX detector on an image and return its detected elements.

    The image is opened, EXIF orientation is applied, and it is converted to
    RGB, resized to ``IMG_SIZE x IMG_SIZE``, scaled to ``[0, 1]`` float32 and
    given a leading batch axis, i.e. shape ``(1, IMG_SIZE, IMG_SIZE, 3)``
    (assumes the model expects NHWC input — confirm against the export).
    The first output's first batch entry is decoded as a detection grid and
    filtered with per-class NMS.

    Args:
        image_path: Path of the image file to analyse.

    Returns:
        ``(pil_img, elements)``: the orientation-corrected RGB image, and the
        detected elements with bboxes in that image's pixel coordinates.

    Raises:
        ValueError: If the module-level ``ort_session`` has not been set.
    """
    if ort_session is None:
        raise ValueError("ONNX model not loaded. Please load the model first.")

    # The context manager closes the underlying file; exif_transpose/convert
    # return new in-memory images that stay valid after it is closed.
    with Image.open(image_path) as src:
        pil_img = ImageOps.exif_transpose(src).convert("RGB")
    orig_w, orig_h = pil_img.size

    resized_img = pil_img.resize((IMG_SIZE, IMG_SIZE), Image.LANCZOS)
    img_array = np.array(resized_img, dtype=np.float32) / 255.0
    input_tensor = np.expand_dims(img_array, axis=0)

    input_name = ort_session.get_inputs()[0].name
    output_name = ort_session.get_outputs()[0].name
    pred_grid = ort_session.run([output_name], {input_name: input_tensor})[0][0]

    raw_boxes = _decode_prediction_grid(pred_grid, orig_w, orig_h)
    return pil_img, _apply_class_nms(raw_boxes)
| # ============================================================================ | |
| # ALIGNMENT DETECTION (unchanged) | |
| # ============================================================================ | |
class AlignmentDetector:
    """Finds groups of elements whose centres or edges line up.

    Grouping is greedy over elements sorted by the coordinate of interest: an
    element joins the current group while its coordinate is within
    ``threshold`` of the group's mean; only groups with two or more members
    are reported.
    """

    def __init__(self, elements: List[Element], threshold: float = ALIGNMENT_THRESHOLD):
        self.elements = elements
        self.threshold = threshold

    def _group_by_center(self, primary: str, secondary: str) -> List[List[Element]]:
        """Group elements by attribute ``primary``; order each group by ``secondary``."""
        if not self.elements:
            return []

        ordered = sorted(self.elements, key=lambda item: getattr(item, primary))
        found = []
        bucket = [ordered[0]]
        for candidate in ordered[1:]:
            # Mean is recomputed from scratch over the bucket in insertion order.
            mean = sum(getattr(member, primary) for member in bucket) / len(bucket)
            if abs(getattr(candidate, primary) - mean) <= self.threshold:
                bucket.append(candidate)
                continue
            if len(bucket) > 1:
                bucket.sort(key=lambda member: getattr(member, secondary))
                found.append(bucket)
            bucket = [candidate]

        if len(bucket) > 1:
            bucket.sort(key=lambda member: getattr(member, secondary))
            found.append(bucket)
        return found

    def detect_horizontal_alignments(self) -> List[List[Element]]:
        """Return rows of elements sharing a centre Y, each row ordered left to right."""
        return self._group_by_center('center_y', 'center_x')

    def detect_vertical_alignments(self) -> List[List[Element]]:
        """Return columns of elements sharing a centre X, each ordered top to bottom."""
        return self._group_by_center('center_x', 'center_y')

    def detect_edge_alignments(self) -> Dict[str, List[List[Element]]]:
        """Return groups of elements whose left, right, top or bottom edges coincide.

        The result always has the keys ``'left'``, ``'right'``, ``'top'`` and
        ``'bottom'`` (in that order), each mapping to a list of groups.
        """
        edge_columns = (('left', 0), ('right', 2), ('top', 1), ('bottom', 3))
        result = {name: [] for name, _ in edge_columns}
        if not self.elements:
            return result

        for name, column in edge_columns:
            edge_of = lambda item, col=column: item.bbox[col]
            result[name] = self._cluster_by_value(sorted(self.elements, key=edge_of), edge_of)
        return result

    def _cluster_by_value(self, elements: List[Element], value_func) -> List[List[Element]]:
        """Cluster pre-sorted ``elements`` whose ``value_func`` lies within threshold.

        The reference value is the running mean of the current cluster; only
        clusters with more than one member are returned.
        """
        if not elements:
            return []

        clusters = []
        cluster = [elements[0]]
        anchor = value_func(elements[0])
        for item in elements[1:]:
            value = value_func(item)
            if abs(value - anchor) > self.threshold:
                if len(cluster) > 1:
                    clusters.append(cluster)
                cluster = [item]
                anchor = value
                continue
            cluster.append(item)
            size = len(cluster)
            anchor = (anchor * (size - 1) + value) / size

        if len(cluster) > 1:
            clusters.append(cluster)
        return clusters
| # ============================================================================ | |
| # SIZE NORMALIZATION (unchanged) | |
| # ============================================================================ | |
class SizeNormalizer:
    """Normalizes element sizes based on type and clustering.

    Elements with the same label are clustered first by width and then, within
    each width cluster, by height; a new cluster starts whenever a value is more
    than ``SIZE_CLUSTERING_THRESHOLD`` from the running cluster mean. Elements
    close to their cluster's median size are then snapped to that median.
    """

    def __init__(self, elements: List["Element"], img_width: float, img_height: float):
        self.elements = elements
        self.img_width = img_width
        self.img_height = img_height
        self.size_clusters = {}  # NOTE(review): not written by any method of this class

    def cluster_sizes_by_type(self) -> Dict[str, List[List["Element"]]]:
        """Return ``{label: [cluster, ...]}`` for every label in ``CLASS_NAMES`` present."""
        clusters_by_type = {}
        for label in CLASS_NAMES:
            type_elements = [e for e in self.elements if e.label == label]
            if not type_elements:
                continue
            final_clusters = []
            for width_cluster in self._cluster_by_dimension(type_elements, 'width'):
                final_clusters.extend(self._cluster_by_dimension(width_cluster, 'height'))
            clusters_by_type[label] = final_clusters
        return clusters_by_type

    def _cluster_by_dimension(self, elements: List["Element"], dimension: str) -> List[List["Element"]]:
        """Cluster elements by the ``'width'`` or ``'height'`` attribute (all returned)."""
        if not elements:
            return []
        sorted_elements = sorted(elements, key=lambda e: getattr(e, dimension))
        clusters = []
        current_cluster = [sorted_elements[0]]
        for elem in sorted_elements[1:]:
            avg_dim = sum(getattr(e, dimension) for e in current_cluster) / len(current_cluster)
            if abs(getattr(elem, dimension) - avg_dim) <= SIZE_CLUSTERING_THRESHOLD:
                current_cluster.append(elem)
            else:
                clusters.append(current_cluster)
                current_cluster = [elem]
        clusters.append(current_cluster)
        return clusters

    @staticmethod
    def _snap_to_median(value: float, median: float) -> int:
        """Round ``median`` if ``value`` is within 30% of it, else round ``value``.

        A non-positive median cannot serve as a relative reference (it would
        divide by zero), so ``value`` is used as-is in that case.
        """
        if median > 0 and abs(value - median) / median < 0.3:
            return round(median)
        return round(value)

    def get_normalized_size(self, element: "Element", size_cluster: List["Element"]) -> Tuple[float, float]:
        """Return ``(width, height)`` for ``element`` rounded to whole pixels.

        Clusters with at least three members provide a median (upper median for
        even sizes) to snap to; smaller clusters leave the element's own size.
        """
        if len(size_cluster) < 3:
            return round(element.width), round(element.height)

        median_index = len(size_cluster) // 2
        median_width = sorted(e.width for e in size_cluster)[median_index]
        median_height = sorted(e.height for e in size_cluster)[median_index]
        return (self._snap_to_median(element.width, median_width),
                self._snap_to_median(element.height, median_height))
| # ============================================================================ | |
| # GRID-BASED LAYOUT SYSTEM (unchanged) | |
| # ============================================================================ | |
class GridLayoutSystem:
    """Grid-based layout system for precise positioning.

    The image is divided into ``num_columns`` equal-width columns; the row
    count is chosen so cells are roughly square (``int(img_height /
    cell_width)``, at least one row), and ``cell_height`` then divides the
    image height evenly among those rows.
    """

    def __init__(self, img_width: float, img_height: float, num_columns: int = GRID_COLUMNS):
        self.img_width = img_width
        self.img_height = img_height
        self.num_columns = num_columns
        self.cell_width = img_width / num_columns
        self.num_rows = max(1, int(img_height / self.cell_width))
        self.cell_height = img_height / self.num_rows
        print(f"π Grid System: {self.num_columns} columns Γ {self.num_rows} rows")
        print(f"π Cell size: {self.cell_width:.1f}px Γ {self.cell_height:.1f}px")

    def snap_to_grid(self, bbox: List[float], element_label: str, preserve_size: bool = True) -> List[float]:
        """Snap ``[x1, y1, x2, y2]`` to whole grid cells around its centre.

        The span is the box size rounded to whole cells (at least one, at most
        the full grid) and the result is kept inside the grid.
        """
        x1, y1, x2, y2 = bbox
        original_width = x2 - x1
        original_height = y2 - y1
        center_col = round(((x1 + x2) / 2) / self.cell_width)
        center_row = round(((y1 + y2) / 2) / self.cell_height)

        width_cells = max(1, round(original_width / self.cell_width))
        height_cells = max(1, round(original_height / self.cell_height))
        if not preserve_size:
            standard = STANDARD_SIZES.get(element_label, {'width': 2, 'height': 1})
            # NOTE(review): width_cells/height_cells are already integers, so
            # these checks only fire when the span already equals the standard
            # size; the snap is effectively a no-op as written.
            if abs(width_cells - standard['width']) <= 0.5:
                width_cells = standard['width']
            if abs(height_cells - standard['height']) <= 0.5:
                height_cells = standard['height']

        # A span larger than the grid would push the box past the image edge
        # even after the start position is clamped to 0.
        width_cells = min(width_cells, self.num_columns)
        height_cells = min(height_cells, self.num_rows)

        start_col = center_col - width_cells // 2
        start_row = center_row - height_cells // 2
        start_col = max(0, min(start_col, self.num_columns - width_cells))
        start_row = max(0, min(start_row, self.num_rows - height_cells))

        return [
            start_col * self.cell_width,
            start_row * self.cell_height,
            (start_col + width_cells) * self.cell_width,
            (start_row + height_cells) * self.cell_height,
        ]

    def get_grid_position(self, bbox: List[float]) -> Dict:
        """Return the cells covered by ``bbox`` (start floored, end ceiled) and spans."""
        x1, y1, x2, y2 = bbox
        start_col = int(x1 / self.cell_width)
        start_row = int(y1 / self.cell_height)
        end_col = int(np.ceil(x2 / self.cell_width))
        end_row = int(np.ceil(y2 / self.cell_height))
        return {
            'start_row': start_row,
            'end_row': end_row,
            'start_col': start_col,
            'end_col': end_col,
            'rowspan': end_row - start_row,
            'colspan': end_col - start_col
        }
| # ============================================================================ | |
| # OVERLAP DETECTION & RESOLUTION (unchanged) | |
| # ============================================================================ | |
class OverlapResolver:
    """Detects and resolves overlapping normalized elements.

    Pairs whose larger per-box overlap fraction reaches ``overlap_threshold``
    are handled from most to least overlapping. Above 70%, the lower-scoring
    element is removed. Above 40%, the boxes are split at a midpoint. Otherwise,
    their facing edges are trimmed. Boxes are edited in place.
    """

    def __init__(self, elements: List["Element"], img_width: float, img_height: float):
        self.elements = elements
        self.img_width = img_width
        self.img_height = img_height
        self.overlap_threshold = 0.2

    def compute_iou(self, bbox1: List[float], bbox2: List[float]) -> float:
        """Compute Intersection over Union between two ``[x1, y1, x2, y2]`` boxes."""
        x1 = max(bbox1[0], bbox2[0])
        y1 = max(bbox1[1], bbox2[1])
        x2 = min(bbox1[2], bbox2[2])
        y2 = min(bbox1[3], bbox2[3])
        if x2 <= x1 or y2 <= y1:
            return 0.0
        intersection = (x2 - x1) * (y2 - y1)
        area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
        area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
        union = area1 + area2 - intersection
        return intersection / union if union > 0 else 0.0

    def compute_overlap_ratio(self, bbox1: List[float], bbox2: List[float]) -> Tuple[float, float]:
        """Return the fraction of each box's own area covered by the intersection."""
        x1 = max(bbox1[0], bbox2[0])
        y1 = max(bbox1[1], bbox2[1])
        x2 = min(bbox1[2], bbox2[2])
        y2 = min(bbox1[3], bbox2[3])
        if x2 <= x1 or y2 <= y1:
            return 0.0, 0.0
        intersection = (x2 - x1) * (y2 - y1)
        area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
        area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
        overlap_ratio1 = intersection / area1 if area1 > 0 else 0.0
        overlap_ratio2 = intersection / area2 if area2 > 0 else 0.0
        return overlap_ratio1, overlap_ratio2

    def resolve_overlaps(self, normalized_elements: List["NormalizedElement"]) -> List["NormalizedElement"]:
        """Resolve overlaps by adjusting or removing elements; returns the kept list.

        Overlapping pairs are collected once up front and then processed in
        descending overlap order; pairs involving an already removed element
        are skipped.
        """
        print("\nπ Checking for overlaps...")
        overlaps = []
        for i, ne1 in enumerate(normalized_elements):
            for j in range(i + 1, len(normalized_elements)):
                ne2 = normalized_elements[j]
                iou = self.compute_iou(ne1.normalized_bbox, ne2.normalized_bbox)
                if iou <= 0:
                    continue
                overlap1, overlap2 = self.compute_overlap_ratio(
                    ne1.normalized_bbox, ne2.normalized_bbox
                )
                max_overlap = max(overlap1, overlap2)
                if max_overlap >= self.overlap_threshold:
                    overlaps.append({
                        'idx1': i,
                        'idx2': j,
                        'elem1': ne1,
                        'elem2': ne2,
                        'overlap': max_overlap,
                        'overlap1': overlap1,
                        'overlap2': overlap2,
                        'iou': iou
                    })

        if not overlaps:
            print("β No significant overlaps detected")
            return normalized_elements

        print(f"β οΈ Found {len(overlaps)} overlapping element pairs")
        overlaps.sort(key=lambda x: x['overlap'], reverse=True)

        elements_to_remove = set()
        for overlap_info in overlaps:
            idx1 = overlap_info['idx1']
            idx2 = overlap_info['idx2']
            if idx1 in elements_to_remove or idx2 in elements_to_remove:
                continue
            elem1 = overlap_info['elem1']
            elem2 = overlap_info['elem2']
            overlap_ratio = overlap_info['overlap']
            if overlap_ratio > 0.7:
                # Near-duplicates: keep only the more confident detection.
                if elem1.original.score < elem2.original.score:
                    elements_to_remove.add(idx1)
                    print(f" ποΈ Removing {elem1.original.label} (conf: {elem1.original.score:.2f}) - "
                          f"overlaps {overlap_ratio * 100:.1f}% with {elem2.original.label}")
                else:
                    elements_to_remove.add(idx2)
                    print(f" ποΈ Removing {elem2.original.label} (conf: {elem2.original.score:.2f}) - "
                          f"overlaps {overlap_ratio * 100:.1f}% with {elem1.original.label}")
            elif overlap_ratio > 0.4:
                self._try_separate_elements(elem1, elem2, overlap_info)
                print(f" βοΈ Separating {elem1.original.label} and {elem2.original.label} "
                      f"(overlap: {overlap_ratio * 100:.1f}%)")
            else:
                self._shrink_overlapping_edges(elem1, elem2, overlap_info)
                print(f" π Shrinking {elem1.original.label} and {elem2.original.label} "
                      f"(overlap: {overlap_ratio * 100:.1f}%)")

        if elements_to_remove:
            normalized_elements = [
                ne for i, ne in enumerate(normalized_elements)
                if i not in elements_to_remove
            ]
            print(f"β Removed {len(elements_to_remove)} completely overlapping elements")
        return normalized_elements

    def _try_separate_elements(self, elem1: "NormalizedElement", elem2: "NormalizedElement",
                               overlap_info: Dict):
        """Split two heavily overlapping boxes at a midpoint along their main offset axis.

        If the centres are further apart horizontally than vertically, the boxes
        are split left/right; otherwise top/bottom. A ``min_gap`` margin is left
        on each side of the midpoint.
        """
        bbox1 = elem1.normalized_bbox
        bbox2 = elem2.normalized_bbox
        center1_x = (bbox1[0] + bbox1[2]) / 2
        center1_y = (bbox1[1] + bbox1[3]) / 2
        center2_x = (bbox2[0] + bbox2[2]) / 2
        center2_y = (bbox2[1] + bbox2[3]) / 2
        dx = abs(center2_x - center1_x)
        dy = abs(center2_y - center1_y)
        min_gap = 3
        if dx > dy:
            if center1_x < center2_x:
                midpoint = (bbox1[2] + bbox2[0]) / 2
                bbox1[2] = midpoint - min_gap
                bbox2[0] = midpoint + min_gap
            else:
                midpoint = (bbox2[2] + bbox1[0]) / 2
                bbox2[2] = midpoint - min_gap
                bbox1[0] = midpoint + min_gap
        else:
            if center1_y < center2_y:
                midpoint = (bbox1[3] + bbox2[1]) / 2
                bbox1[3] = midpoint - min_gap
                bbox2[1] = midpoint + min_gap
            else:
                midpoint = (bbox2[3] + bbox1[1]) / 2
                bbox2[3] = midpoint - min_gap
                bbox1[1] = midpoint + min_gap
        self._ensure_valid_bbox(bbox1)
        self._ensure_valid_bbox(bbox2)

    def _shrink_overlapping_edges(self, elem1: "NormalizedElement", elem2: "NormalizedElement",
                                  overlap_info: Dict):
        """Trim the facing edges of two moderately overlapping boxes.

        Separation happens along the axis on which the intersection is thinner,
        because that removes the overlap with the least loss of area. The
        intersection is a tall strip for side-by-side boxes and a wide strip
        for stacked boxes. Each box gives up half the overlap plus ``gap``.
        """
        bbox1 = elem1.normalized_bbox
        bbox2 = elem2.normalized_bbox
        overlap_width = min(bbox1[2], bbox2[2]) - max(bbox1[0], bbox2[0])
        overlap_height = min(bbox1[3], bbox2[3]) - max(bbox1[1], bbox2[1])
        gap = 2
        if overlap_width < overlap_height:
            # Side by side: move the vertical edges apart.
            shrink = overlap_width / 2 + gap
            if bbox1[0] < bbox2[0]:
                bbox1[2] -= shrink
                bbox2[0] += shrink
            else:
                bbox2[2] -= shrink
                bbox1[0] += shrink
        else:
            # Stacked: move the horizontal edges apart.
            shrink = overlap_height / 2 + gap
            if bbox1[1] < bbox2[1]:
                bbox1[3] -= shrink
                bbox2[1] += shrink
            else:
                bbox2[3] -= shrink
                bbox1[1] += shrink
        self._ensure_valid_bbox(bbox1)
        self._ensure_valid_bbox(bbox2)

    def _ensure_valid_bbox(self, bbox: List[float]):
        """Grow ``bbox`` in place to at least 8x8 px about its centre, then clip to the image."""
        min_size = 8
        if bbox[2] - bbox[0] < min_size:
            center_x = (bbox[0] + bbox[2]) / 2
            bbox[0] = center_x - min_size / 2
            bbox[2] = center_x + min_size / 2
        if bbox[3] - bbox[1] < min_size:
            center_y = (bbox[1] + bbox[3]) / 2
            bbox[1] = center_y - min_size / 2
            bbox[3] = center_y + min_size / 2
        bbox[0] = max(0, min(bbox[0], self.img_width))
        bbox[1] = max(0, min(bbox[1], self.img_height))
        bbox[2] = max(0, min(bbox[2], self.img_width))
        bbox[3] = max(0, min(bbox[3], self.img_height))
| # ============================================================================ | |
| # MAIN NORMALIZATION ENGINE (unchanged) | |
| # ============================================================================ | |
class LayoutNormalizer:
    """Main engine for normalizing a wireframe layout.

    Pipeline: size-cluster elements per type, snap each normalized box to the
    grid, apply centre/edge alignment corrections, resolve overlaps, and
    finally recompute each element's grid position from its final box.
    """

    def __init__(self, elements: List[Element], img_width: float, img_height: float):
        self.elements = elements
        self.img_width = img_width
        self.img_height = img_height
        self.grid = GridLayoutSystem(img_width, img_height)
        self.alignment_detector = AlignmentDetector(elements)
        self.size_normalizer = SizeNormalizer(elements, img_width, img_height)

    def normalize_layout(self) -> List[NormalizedElement]:
        """Normalize all elements with proper sizing and alignment."""
        print("\nπ§ Starting layout normalization...")
        h_alignments = self.alignment_detector.detect_horizontal_alignments()
        v_alignments = self.alignment_detector.detect_vertical_alignments()
        edge_alignments = self.alignment_detector.detect_edge_alignments()
        print(f"β Found {len(h_alignments)} horizontal alignment groups")
        print(f"β Found {len(v_alignments)} vertical alignment groups")

        size_clusters = self.size_normalizer.cluster_sizes_by_type()
        print(f"β Created size clusters for {len(size_clusters)} element types")

        # Elements are keyed by identity so equal-valued detections stay distinct.
        element_to_cluster = {}
        element_to_size_category = {}
        for label, clusters in size_clusters.items():
            for i, cluster in enumerate(clusters):
                category = f"{label}_size_{i + 1}"
                for elem in cluster:
                    element_to_cluster[id(elem)] = cluster
                    element_to_size_category[id(elem)] = category

        normalized_elements = []
        for elem in self.elements:
            cluster = element_to_cluster.get(id(elem), [elem])
            size_category = element_to_size_category.get(id(elem), f"{elem.label}_default")
            norm_width, norm_height = self.size_normalizer.get_normalized_size(elem, cluster)
            center_x, center_y = elem.center_x, elem.center_y
            norm_bbox = [
                center_x - norm_width / 2,
                center_y - norm_height / 2,
                center_x + norm_width / 2,
                center_y + norm_height / 2
            ]
            snapped_bbox = self.grid.snap_to_grid(norm_bbox, elem.label, preserve_size=True)
            normalized_elements.append(NormalizedElement(
                original=elem,
                normalized_bbox=snapped_bbox,
                grid_position=self.grid.get_grid_position(snapped_bbox),
                size_category=size_category
            ))

        normalized_elements = self._apply_alignment_corrections(
            normalized_elements, h_alignments, v_alignments, edge_alignments
        )
        overlap_resolver = OverlapResolver(self.elements, self.img_width, self.img_height)
        normalized_elements = overlap_resolver.resolve_overlaps(normalized_elements)

        # Alignment and overlap passes move boxes after they were snapped, so the
        # grid positions computed above are stale; refresh them from the final
        # boxes so visualization/export describe what is actually drawn.
        for ne in normalized_elements:
            ne.grid_position = self.grid.get_grid_position(ne.normalized_bbox)

        print(f"β Normalized {len(normalized_elements)} elements")
        return normalized_elements

    def _apply_alignment_corrections(self, normalized_elements: List[NormalizedElement],
                                     h_alignments: List[List[Element]],
                                     v_alignments: List[List[Element]],
                                     edge_alignments: Dict) -> List[NormalizedElement]:
        """Shift normalized boxes in place so aligned groups share coordinates.

        Passes run in order: horizontal groups share a mean centre Y, vertical
        groups a mean centre X, then edge groups a mean left/right/top/bottom
        edge. Box sizes are preserved; later passes can override earlier ones.
        """
        elem_to_normalized = {id(ne.original): ne for ne in normalized_elements}

        for h_group in h_alignments:
            norm_group = [elem_to_normalized[id(e)] for e in h_group if id(e) in elem_to_normalized]
            if len(norm_group) > 1:
                avg_y = sum((ne.normalized_bbox[1] + ne.normalized_bbox[3]) / 2 for ne in norm_group) / len(norm_group)
                for ne in norm_group:
                    height = ne.normalized_bbox[3] - ne.normalized_bbox[1]
                    ne.normalized_bbox[1] = avg_y - height / 2
                    ne.normalized_bbox[3] = avg_y + height / 2

        for v_group in v_alignments:
            norm_group = [elem_to_normalized[id(e)] for e in v_group if id(e) in elem_to_normalized]
            if len(norm_group) > 1:
                avg_x = sum((ne.normalized_bbox[0] + ne.normalized_bbox[2]) / 2 for ne in norm_group) / len(norm_group)
                for ne in norm_group:
                    width = ne.normalized_bbox[2] - ne.normalized_bbox[0]
                    ne.normalized_bbox[0] = avg_x - width / 2
                    ne.normalized_bbox[2] = avg_x + width / 2

        for edge_type, groups in edge_alignments.items():
            for edge_group in groups:
                norm_group = [elem_to_normalized[id(e)] for e in edge_group if id(e) in elem_to_normalized]
                if len(norm_group) <= 1:
                    continue
                if edge_type == 'left':
                    avg_left = sum(ne.normalized_bbox[0] for ne in norm_group) / len(norm_group)
                    for ne in norm_group:
                        width = ne.normalized_bbox[2] - ne.normalized_bbox[0]
                        ne.normalized_bbox[0] = avg_left
                        ne.normalized_bbox[2] = avg_left + width
                elif edge_type == 'right':
                    avg_right = sum(ne.normalized_bbox[2] for ne in norm_group) / len(norm_group)
                    for ne in norm_group:
                        width = ne.normalized_bbox[2] - ne.normalized_bbox[0]
                        ne.normalized_bbox[2] = avg_right
                        ne.normalized_bbox[0] = avg_right - width
                elif edge_type == 'top':
                    avg_top = sum(ne.normalized_bbox[1] for ne in norm_group) / len(norm_group)
                    for ne in norm_group:
                        height = ne.normalized_bbox[3] - ne.normalized_bbox[1]
                        ne.normalized_bbox[1] = avg_top
                        ne.normalized_bbox[3] = avg_top + height
                elif edge_type == 'bottom':
                    avg_bottom = sum(ne.normalized_bbox[3] for ne in norm_group) / len(norm_group)
                    for ne in norm_group:
                        height = ne.normalized_bbox[3] - ne.normalized_bbox[1]
                        ne.normalized_bbox[3] = avg_bottom
                        ne.normalized_bbox[1] = avg_bottom - height
        return normalized_elements
| # ============================================================================ | |
| # VISUALIZATION & EXPORT (unchanged) | |
| # ============================================================================ | |
def visualize_comparison(pil_img: Image.Image, elements: List[Element],
                         normalized_elements: List[NormalizedElement],
                         grid_system: GridLayoutSystem):
    """Show raw detections and the normalized layout side by side.

    Left panel: every raw detection as a red box with its label. Right panel:
    the layout grid as dotted blue lines, each normalized box outlined in a
    per-class colour, its original box as a dashed grey outline, and a caption
    with label, size category and starting grid row/column. Calls
    ``plt.show()``; nothing is returned.
    """
    fig, (raw_ax, norm_ax) = plt.subplots(1, 2, figsize=(24, 12))

    # Left panel: raw model output.
    raw_ax.imshow(pil_img)
    raw_ax.set_title("Original Predictions", fontsize=16, weight='bold')
    raw_ax.axis('off')
    for det in elements:
        left, top, right, bottom = det.bbox
        raw_ax.add_patch(patches.Rectangle(
            (left, top), right - left, bottom - top,
            linewidth=2, edgecolor='red', facecolor='none'
        ))
        raw_ax.text(left, top - 5, det.label, color='red', fontsize=8,
                    bbox=dict(facecolor='white', alpha=0.7))

    # Right panel: grid overlay plus normalized boxes.
    norm_ax.imshow(pil_img)
    norm_ax.set_title("Normalized & Aligned Layout", fontsize=16, weight='bold')
    norm_ax.axis('off')
    grid_line_style = dict(color='blue', linestyle=':', linewidth=0.5, alpha=0.3)
    for col_idx in range(grid_system.num_columns + 1):
        norm_ax.axvline(x=col_idx * grid_system.cell_width, **grid_line_style)
    for row_idx in range(grid_system.num_rows + 1):
        norm_ax.axhline(y=row_idx * grid_system.cell_height, **grid_line_style)

    np.random.seed(42)
    palette = plt.cm.Set3(np.linspace(0, 1, len(CLASS_NAMES)))
    color_map = dict(zip(CLASS_NAMES, palette))

    for ne in normalized_elements:
        nx1, ny1, nx2, ny2 = ne.normalized_bbox
        shade = color_map[ne.original.label]
        norm_ax.add_patch(patches.Rectangle(
            (nx1, ny1), nx2 - nx1, ny2 - ny1,
            linewidth=3, edgecolor=shade, facecolor='none'
        ))
        ox1, oy1, ox2, oy2 = ne.original.bbox
        norm_ax.add_patch(patches.Rectangle(
            (ox1, oy1), ox2 - ox1, oy2 - oy1,
            linewidth=1, edgecolor='gray', facecolor='none',
            linestyle='--', alpha=0.5
        ))
        pos = ne.grid_position
        caption = f"{ne.original.label}\n{ne.size_category}\nR{pos['start_row']} C{pos['start_col']}"
        norm_ax.text(nx1 + 5, ny1 + 15, caption, color='white', fontsize=7,
                     bbox=dict(facecolor=shade, alpha=0.8, pad=2))

    plt.tight_layout()
    plt.show()
def export_to_json(normalized_elements: List[NormalizedElement],
                   grid_system: GridLayoutSystem,
                   output_path: str):
    """Export normalized layout to JSON.

    The document has two top-level keys:

    * ``metadata``: image size, grid description (columns, rows, cell size)
      and the element count.
    * ``elements``: one entry per normalized element. Each entry holds its
      type, confidence, size category, original and normalized boxes (with
      width/height), grid position, and the normalized box as percentages of
      the image size.

    Coordinates are rounded to 2 decimals and confidence to 3. Numeric values
    are coerced to built-in ``float`` first, so NumPy scalars (e.g. float32
    values from ONNX output) serialize cleanly. The parent directory of
    ``output_path`` is created if needed.
    """
    def _r(value, digits=2):
        # round() on a NumPy float32 returns float32, which json cannot encode.
        return round(float(value), digits)

    def _json_default(obj):
        # Fields copied through as-is (image size, grid_position) may still
        # hold NumPy scalars; unwrap them to the equivalent Python value.
        if isinstance(obj, np.generic):
            return obj.item()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    img_w = grid_system.img_width
    img_h = grid_system.img_height

    elements_out = []
    for i, norm_elem in enumerate(normalized_elements):
        orig = norm_elem.original
        norm_bbox = norm_elem.normalized_bbox
        elements_out.append({
            'id': i,
            'type': orig.label,
            'confidence': _r(orig.score, 3),
            'size_category': norm_elem.size_category,
            'original_bbox': {
                'x1': _r(orig.bbox[0]),
                'y1': _r(orig.bbox[1]),
                'x2': _r(orig.bbox[2]),
                'y2': _r(orig.bbox[3]),
                'width': _r(orig.width),
                'height': _r(orig.height)
            },
            'normalized_bbox': {
                'x1': _r(norm_bbox[0]),
                'y1': _r(norm_bbox[1]),
                'x2': _r(norm_bbox[2]),
                'y2': _r(norm_bbox[3]),
                'width': _r(norm_bbox[2] - norm_bbox[0]),
                'height': _r(norm_bbox[3] - norm_bbox[1])
            },
            'grid_position': norm_elem.grid_position,
            'percentage': {
                'x1': _r((norm_bbox[0] / img_w) * 100),
                'y1': _r((norm_bbox[1] / img_h) * 100),
                'x2': _r((norm_bbox[2] / img_w) * 100),
                'y2': _r((norm_bbox[3] / img_h) * 100)
            }
        })

    output = {
        'metadata': {
            'image_width': img_w,
            'image_height': img_h,
            'grid_system': {
                'columns': grid_system.num_columns,
                'rows': grid_system.num_rows,
                'cell_width': _r(grid_system.cell_width),
                'cell_height': _r(grid_system.cell_height)
            },
            'total_elements': len(normalized_elements)
        },
        'elements': elements_out
    }

    # dirname is '' for a bare filename; fall back to the current directory.
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output, f, indent=2, default=_json_default)
    print(f"\nβ Exported normalized layout to: {output_path}")
def export_to_html(normalized_elements: List[NormalizedElement],
                   grid_system: GridLayoutSystem,
                   output_path: str):
    """Export normalized layout as a standalone HTML/CSS page.

    Each normalized element becomes an absolutely positioned ``<div>`` inside
    a container of the source image's size. The div's left/top/width/height
    come from ``normalized_bbox`` in pixels, and its CSS class is the element
    label (styled per class below). A fixed info panel shows the grid size,
    element count and image dimensions.

    Label and size-category text is HTML-escaped before it is interpolated,
    so unexpected characters cannot break the generated markup. The parent
    directory of ``output_path`` is created if needed. The file is written
    as UTF-8.
    """
    import html

    # Literal CSS braces are doubled because the template goes through str.format.
    html_template = """<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Wireframe Layout</title>
    <style>
        * {{
            margin: 0;
            padding: 0;
            box-sizing: border-box;
        }}
        body {{
            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Arial, sans-serif;
            background: #f5f5f5;
            padding: 20px;
        }}
        .container {{
            max-width: {img_width}px;
            margin: 0 auto;
            background: white;
            position: relative;
            height: {img_height}px;
            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
        }}
        .element {{
            position: absolute;
            border: 2px solid #333;
            display: flex;
            align-items: center;
            justify-content: center;
            font-size: 12px;
            color: #666;
            background: rgba(255,255,255,0.9);
            transition: all 0.3s ease;
        }}
        .element:hover {{
            z-index: 100;
            box-shadow: 0 4px 12px rgba(0,0,0,0.2);
            transform: scale(1.02);
        }}
        .element-label {{
            font-weight: bold;
            font-size: 10px;
            text-transform: uppercase;
        }}
        .button {{
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            border-radius: 6px;
            font-weight: bold;
            cursor: pointer;
        }}
        .checkbox {{
            background: white;
            border: 2px solid #4a5568;
            border-radius: 4px;
        }}
        .textfield {{
            background: white;
            border: 2px solid #cbd5e0;
            border-radius: 4px;
            padding: 8px;
        }}
        .text {{
            background: transparent;
            border: 1px dashed #cbd5e0;
            color: #2d3748;
        }}
        .paragraph {{
            background: transparent;
            border: 1px dashed #cbd5e0;
            color: #4a5568;
            text-align: left;
            padding: 8px;
        }}
        .image {{
            background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
            color: white;
            border: none;
        }}
        .navbar {{
            background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%);
            color: white;
            font-weight: bold;
            border: none;
        }}
        .info-panel {{
            position: fixed;
            top: 20px;
            right: 20px;
            background: white;
            padding: 20px;
            border-radius: 8px;
            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
            max-width: 300px;
        }}
        .info-panel h3 {{
            margin-bottom: 10px;
            color: #2d3748;
        }}
        .info-panel p {{
            margin: 5px 0;
            font-size: 14px;
            color: #4a5568;
        }}
    </style>
</head>
<body>
    <div class="info-panel">
        <h3>π Layout Info</h3>
        <p><strong>Grid:</strong> {grid_cols} Γ {grid_rows}</p>
        <p><strong>Elements:</strong> {total_elements}</p>
        <p><strong>Dimensions:</strong> {img_width}px Γ {img_height}px</p>
        <p style="margin-top: 15px; font-size: 12px; color: #718096;">
            Hover over elements to see details
        </p>
    </div>
    <div class="container">
        {elements_html}
    </div>
</body>
</html>"""

    elements_html = []
    for norm_elem in normalized_elements:
        x1, y1, x2, y2 = norm_elem.normalized_bbox
        width = x2 - x1
        height = y2 - y1
        grid_pos = norm_elem.grid_position
        # Escape everything that lands in markup; quote=True (the default)
        # also escapes quotes, which matters inside the attribute values.
        label = html.escape(str(norm_elem.original.label))
        title = html.escape(
            f"{norm_elem.original.label} | Grid: R{grid_pos['start_row']} "
            f"C{grid_pos['start_col']} | Size: {norm_elem.size_category}"
        )
        elements_html.append(f"""
        <div class="element {label}"
             style="left: {x1}px; top: {y1}px; width: {width}px; height: {height}px;"
             title="{title}">
            <span class="element-label">{label}</span>
        </div>""")

    html_content = html_template.format(
        img_width=int(grid_system.img_width),
        img_height=int(grid_system.img_height),
        grid_cols=grid_system.num_columns,
        grid_rows=grid_system.num_rows,
        total_elements=len(normalized_elements),
        elements_html='\n'.join(elements_html)
    )

    # dirname is '' for a bare filename; fall back to the current directory.
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(html_content)
    print(f"β Exported HTML layout to: {output_path}")
| # ============================================================================ | |
| # MAIN PIPELINE - MODIFIED FOR ONNX | |
| # ============================================================================ | |
def _load_ort_session_once() -> bool:
    """Create the module-level ONNX Runtime session on first use.

    Returns:
        True when ``ort_session`` is available. False when loading raised; the
        error is printed. If the failure happened before the session was
        assigned, ``ort_session`` stays None and a later call retries.
    """
    global ort_session
    if ort_session is not None:
        return True
    print("\nπ¦ Loading ONNX model...")
    print("Model path:", MODEL_PATH)
    print("Model path exists?", os.path.exists(MODEL_PATH))
    try:
        ort_session = ort.InferenceSession(MODEL_PATH)
        print("β ONNX model loaded successfully!")
        print(f"Input name: {ort_session.get_inputs()[0].name}")
        print(f"Input shape: {ort_session.get_inputs()[0].shape}")
        print(f"Output name: {ort_session.get_outputs()[0].name}")
        print(f"Output shape: {ort_session.get_outputs()[0].shape}")
    except Exception as e:
        print(f"β Error loading ONNX model: {e}")
        return False
    return True


def _print_processing_summary(elements, normalized_elements, normalizer) -> None:
    """Print per-type element counts, size-category count and alignment groups."""
    from collections import Counter

    print("\n" + "=" * 80)
    print("π PROCESSING SUMMARY")
    print("=" * 80)

    type_counts = Counter(elem.label for elem in elements)
    print(f"\nπ¦ Element Types:")
    for elem_type, count in sorted(type_counts.items()):
        print(f" β’ {elem_type}: {count}")

    # Only the number of distinct categories is reported.
    num_size_categories = len({ne.size_category for ne in normalized_elements})
    print(f"\nπ Size Categories: {num_size_categories}")

    h_alignments = normalizer.alignment_detector.detect_horizontal_alignments()
    v_alignments = normalizer.alignment_detector.detect_vertical_alignments()
    print(f"\nπ Alignment:")
    print(f" β’ Horizontal groups: {len(h_alignments)}")
    print(f" β’ Vertical groups: {len(v_alignments)}")

    print("\n" + "=" * 80)
    print("β PROCESSING COMPLETE!")
    print("=" * 80 + "\n")


def process_wireframe(image_path: str,
                      save_json: bool = True,
                      save_html: bool = True,
                      show_visualization: bool = True) -> Dict:
    """
    Complete pipeline to process wireframe image.

    Steps: lazily load the ONNX model, run detection via ``get_predictions``,
    normalize the layout with ``LayoutNormalizer``, optionally export
    JSON/HTML into ``OUTPUT_DIR``, optionally show a matplotlib comparison,
    then print a summary.

    Args:
        image_path: Path to wireframe image
        save_json: Export normalized layout as JSON
        save_html: Export normalized layout as HTML
        show_visualization: Display matplotlib comparison

    Returns:
        Dict with keys 'image', 'original_elements', 'normalized_elements'
        and 'grid_system'. It also has 'json_path' / 'html_path' when those
        exports ran. An empty dict is returned if the model fails to load,
        prediction raises, or no elements are detected; errors are printed,
        not raised.
    """
    print("=== PROCESS_WIREFRAME START ===")
    print("Input image path:", image_path)
    file_exists = os.path.exists(image_path)
    print("File exists:", file_exists)
    if file_exists:
        print("File size:", os.path.getsize(image_path))
    print("=" * 80)
    print("π WIREFRAME LAYOUT NORMALIZER (ONNX)")
    print("=" * 80)

    # Step 1: Load ONNX model (once per process) and get predictions
    if not _load_ort_session_once():
        return {}

    print(f"\nπΈ Processing image: {image_path}")
    print("Running detection inferenceβ¦")
    try:
        pil_img, elements = get_predictions(image_path)
        print(f"β Detected {len(elements)} elements")
        for elem in elements:
            print(f" - {elem.label} (conf: {elem.score:.3f}) at {elem.bbox}")
    except Exception as e:
        print(f"β Error during prediction: {e}")
        return {}

    if not elements:
        print("β οΈ No elements detected.")
        print("β Check thresholds:")
        print(f" CONF_THRESHOLD: {CONF_THRESHOLD}")
        print(f" IOU_THRESHOLD: {IOU_THRESHOLD}")
        return {}

    # Step 2: Normalize layout
    normalizer = LayoutNormalizer(elements, pil_img.width, pil_img.height)
    normalized_elements = normalizer.normalize_layout()

    # Step 3: Generate outputs
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    base_filename = os.path.splitext(os.path.basename(image_path))[0]
    results = {
        'image': pil_img,
        'original_elements': elements,
        'normalized_elements': normalized_elements,
        'grid_system': normalizer.grid
    }

    if save_json:
        json_path = os.path.join(OUTPUT_DIR, f"{base_filename}_normalized.json")
        export_to_json(normalized_elements, normalizer.grid, json_path)
        results['json_path'] = json_path

    if save_html:
        html_path = os.path.join(OUTPUT_DIR, f"{base_filename}_layout.html")
        export_to_html(normalized_elements, normalizer.grid, html_path)
        results['html_path'] = html_path

    if show_visualization:
        print("\nπ¨ Generating visualization...")
        visualize_comparison(pil_img, elements, normalized_elements, normalizer.grid)

    _print_processing_summary(elements, normalized_elements, normalizer)
    return results
def batch_process(image_dir: str, pattern: str = "*.png"):
    """Process multiple wireframe images in a directory.

    Every file in ``image_dir`` matching the glob ``pattern`` is run through
    ``process_wireframe`` with JSON and HTML export on and visualization off.
    Files are processed in sorted order, so runs are deterministic.

    Returns:
        A list with one dict per image. Each dict has 'image_path' and
        'success'. Successful entries also carry 'results'; failed entries
        carry 'error'. An empty list is returned when nothing matches.
    """
    import glob

    # glob order depends on the filesystem; sort for reproducible runs.
    image_paths = sorted(glob.glob(os.path.join(image_dir, pattern)))
    if not image_paths:
        print(f"β No images found matching pattern: {pattern}")
        return []

    total = len(image_paths)
    print(f"π Found {total} images to process\n")

    all_results = []
    for i, image_path in enumerate(image_paths, 1):
        print(f"\n{'=' * 80}")
        print(f"Processing image {i}/{total}: {os.path.basename(image_path)}")
        print(f"{'=' * 80}")
        try:
            results = process_wireframe(
                image_path,
                save_json=True,
                save_html=True,
                show_visualization=False
            )
        except Exception as e:
            print(f"β Error processing {image_path}: {str(e)}")
            all_results.append({
                'image_path': image_path,
                'success': False,
                'error': str(e)
            })
            continue

        # process_wireframe reports its own failures (model load, prediction
        # error, no detections) by returning an empty dict rather than raising.
        if results:
            all_results.append({
                'image_path': image_path,
                'success': True,
                'results': results
            })
        else:
            all_results.append({
                'image_path': image_path,
                'success': False,
                'error': 'process_wireframe returned no results (see log above)'
            })

    successful = sum(1 for r in all_results if r['success'])
    print(f"\n{'=' * 80}")
    print(f"π BATCH PROCESSING COMPLETE")
    print(f"{'=' * 80}")
    print(f"β Successful: {successful}/{total}")
    print(f"β Failed: {total - successful}/{total}")
    return all_results
| # ============================================================================ | |
| # EXAMPLE USAGE | |
| # ============================================================================ | |
if __name__ == "__main__":
    import sys

    # Usage: python <this_script> [IMAGE_PATH]
    # With no argument, the bundled sample image is processed.
    image_path = sys.argv[1] if len(sys.argv) > 1 else "./image/6LHls1vE.jpg"

    # Process with all outputs
    results = process_wireframe(
        image_path,
        save_json=True,
        save_html=True,
        show_visualization=True
    )

    # Or batch process multiple images:
    # batch_results = batch_process("./wireframes/", pattern="*.png")