import colorgram
import cv2
import numpy as np
from PIL import Image
import json
import torch
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
class DesignTokenExtractor:
    def __init__(self):
        # Load models once at startup
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.pix2struct_model = None
        self.pix2struct_processor = None
        self._load_models()

    def _load_models(self):
        """Load models once at initialization to prevent repeated downloads."""
        try:
            self.pix2struct_processor = Pix2StructProcessor.from_pretrained(
                "google/pix2struct-screen2words-base"
            )
            self.pix2struct_model = Pix2StructForConditionalGeneration.from_pretrained(
                "google/pix2struct-screen2words-base",
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            ).to(self.device)
        except Exception as e:
            print(f"Warning: Could not load Pix2Struct model: {e}")
            # Continue without the model; color and spacing extraction still work
    def extract_colors(self, image_path, num_colors=8):
        """Extract dominant colors using colorgram."""
        try:
            colors = colorgram.extract(image_path, num_colors)
            palette = {}
            for i, color in enumerate(colors):
                # Assign a semantic role based on rank and coverage
                if i == 0:
                    # The largest color is the background only if it dominates the image
                    name = "background" if color.proportion > 0.3 else "dominant"
                elif i == 1:
                    name = "primary"
                elif i == 2:
                    name = "secondary"
                else:
                    name = f"accent-{i - 2}"
                palette[name] = {
                    "hex": f"#{color.rgb.r:02x}{color.rgb.g:02x}{color.rgb.b:02x}",
                    "rgb": f"rgb({color.rgb.r}, {color.rgb.g}, {color.rgb.b})",
                    "proportion": round(color.proportion, 3),
                }
            return palette
        except Exception as e:
            print(f"Error extracting colors: {e}")
            return self._get_default_colors()
    def detect_spacing(self, image):
        """Analyze vertical spacing patterns using OpenCV edge detection."""
        try:
            gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(gray, 50, 150)
            # Find contours for element detection
            contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            # Keep only boxes large enough to be real UI elements
            bounding_boxes = [cv2.boundingRect(c) for c in contours if cv2.contourArea(c) > 100]
            if len(bounding_boxes) > 1:
                # Sort by y-coordinate to measure vertical gaps between elements
                bounding_boxes.sort(key=lambda box: box[1])
                vertical_gaps = []
                for i in range(len(bounding_boxes) - 1):
                    gap = bounding_boxes[i + 1][1] - (bounding_boxes[i][1] + bounding_boxes[i][3])
                    if gap > 0:
                        vertical_gaps.append(gap)
                # Group similar gap values into a spacing scale
                return self._cluster_spacing_values(vertical_gaps)
            # Fewer than two elements detected: fall back to defaults
            return {"small": "8px", "medium": "16px", "large": "32px"}
        except Exception as e:
            print(f"Error detecting spacing: {e}")
            return {"small": "8px", "medium": "16px", "large": "32px"}  # Defaults
    def _cluster_spacing_values(self, gaps):
        """Group similar spacing values into a small/medium/large scale."""
        if not gaps:
            return {"small": "8px", "medium": "16px", "large": "32px"}
        # set() discards order, so sort after deduplicating
        unique_gaps = sorted(set(gaps))
        if len(unique_gaps) >= 3:
            return {
                "small": f"{unique_gaps[0]}px",
                "medium": f"{unique_gaps[len(unique_gaps) // 2]}px",
                "large": f"{unique_gaps[-1]}px",
            }
        if len(unique_gaps) == 2:
            return {
                "small": f"{unique_gaps[0]}px",
                "large": f"{unique_gaps[1]}px",
            }
        return {"base": f"{unique_gaps[0]}px"}
    def analyze_components(self, image):
        """Use Pix2Struct to produce a natural-language description of the UI."""
        if self.pix2struct_model is None or self.pix2struct_processor is None:
            # Fall back gracefully if model loading failed
            return {
                "detected_elements": "Model not available - basic extraction only",
                "layout": "responsive",
            }
        try:
            inputs = self.pix2struct_processor(images=image, return_tensors="pt").to(self.device)
            with torch.no_grad():
                generated_ids = self.pix2struct_model.generate(**inputs, max_length=100)
            description = self.pix2struct_processor.batch_decode(
                generated_ids, skip_special_tokens=True
            )[0]
            # Parse the description for layout hints
            return {
                "detected_elements": description,
                "layout": "responsive" if "responsive" in description.lower() else "fixed",
            }
        except Exception as e:
            print(f"Error analyzing components: {e}")
            return {
                "detected_elements": "Error during analysis",
                "layout": "responsive",
            }
    def detect_typography(self, image):
        """Basic typography detection (static defaults for now)."""
        # Simplified typography stack without EasyOCR for the initial implementation
        return {
            "heading": {"family": "sans-serif", "size": "32px", "weight": "700"},
            "body": {"family": "sans-serif", "size": "16px", "weight": "400"},
            "caption": {"family": "sans-serif", "size": "14px", "weight": "400"},
        }
    def _get_default_colors(self):
        """Return a fallback color palette."""
        return {
            "primary": {"hex": "#3B82F6", "rgb": "rgb(59, 130, 246)", "proportion": 0.25},
            "secondary": {"hex": "#8B5CF6", "rgb": "rgb(139, 92, 246)", "proportion": 0.15},
            "background": {"hex": "#FFFFFF", "rgb": "rgb(255, 255, 255)", "proportion": 0.40},
            "text": {"hex": "#1F2937", "rgb": "rgb(31, 41, 55)", "proportion": 0.20},
        }
    def resize_for_processing(self, image, max_dimension=1024):
        """Resize large images while maintaining aspect ratio."""
        if max(image.size) > max_dimension:
            ratio = max_dimension / max(image.size)
            new_size = tuple(int(dim * ratio) for dim in image.size)
            return image.resize(new_size, Image.Resampling.LANCZOS)
        return image
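
# A minimal usage sketch showing how the extractor could be driven end to end.
# The screenshot path and output file name below are placeholder assumptions,
# not part of the class itself; json (imported above) serializes the tokens.
if __name__ == "__main__":
    extractor = DesignTokenExtractor()

    # "screenshot.png" is a hypothetical input image
    image = Image.open("screenshot.png").convert("RGB")
    image = extractor.resize_for_processing(image)

    tokens = {
        "colors": extractor.extract_colors("screenshot.png"),
        "spacing": extractor.detect_spacing(image),
        "typography": extractor.detect_typography(image),
        "components": extractor.analyze_components(image),
    }

    # Serialize the combined token set, e.g. for a design-system pipeline
    with open("design_tokens.json", "w") as f:
        json.dump(tokens, f, indent=2)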