"""Build a hue-balanced set of square JPEG tiles from the Unsplash-Lite-Palette dataset."""

import colorsys
import io
import math
import os
import random
from typing import Optional

import requests
from PIL import Image
from datasets import load_dataset

# Colors whose saturation or value is at or below these levels are greys or
# near-blacks; their hue is numerically unstable, so they do not vote.
_MIN_SATURATION = 0.2
_MIN_VALUE = 0.2


def rgb_to_hsv(rgb):
    """Convert an RGB triple with channels in [0, 255] to an HSV tuple (h, s, v).

    Each component of the result is in [0, 1], as returned by ``colorsys``.
    """
    r, g, b = (c / 255.0 for c in rgb)
    return colorsys.rgb_to_hsv(r, g, b)


def dominant_hue(palettes) -> Optional[float]:
    """Return the dominant hue, in degrees within [0, 360), of a palettes mapping.

    ``palettes`` maps palette names to lists of RGB triples (0-255); groups that
    are ``None`` or empty are skipped. Colors with saturation or value not above
    0.2 are ignored because their hue carries little information.

    Hue is an angle, so the colors are combined with a circular (vector) mean:
    an arithmetic mean would place a palette of 350° and 10° reds at 180°.

    Returns ``None`` when no colors survive the filter, or when the surviving
    hues cancel out (e.g. spread evenly around the wheel) so no hue dominates.
    """
    colors = [c for group in palettes.values() if group for c in group]
    hues = [
        h
        for h, s, v in map(rgb_to_hsv, colors)
        if s > _MIN_SATURATION and v > _MIN_VALUE
    ]
    if not hues:
        return None

    angles = [2 * math.pi * h for h in hues]
    x = sum(map(math.cos, angles))
    y = sum(map(math.sin, angles))
    # A (near-)zero resultant vector means there is no meaningful mean direction.
    if math.hypot(x, y) < 1e-9 * len(hues):
        return None
    return math.degrees(math.atan2(y, x)) % 360


def _group_by_hue(rows, bins):
    """Bucket each row's ``url`` into one of ``bins`` equal-width hue ranges."""
    hue_bins = [[] for _ in range(bins)]
    width = 360 / bins
    for row in rows:
        hue = dominant_hue(row["palettes"])
        if hue is not None:
            # The modulo guards against float rounding yielding exactly 360.
            hue_bins[int(hue / width) % bins].append(row["url"])
    return hue_bins


def _select_balanced(hue_bins, limit, rng):
    """Pick up to ``limit`` distinct URLs: an equal quota per bin, then a top-up.

    The top-up draws only from URLs not already taken, so no image repeats.
    Bins are shuffled in place using ``rng``.
    """
    per_bin = max(1, limit // len(hue_bins))
    selected, leftovers = [], []
    for urls in hue_bins:
        rng.shuffle(urls)
        selected.extend(urls[:per_bin])
        leftovers.extend(urls[per_bin:])

    if len(selected) < limit:
        rng.shuffle(leftovers)
        selected.extend(leftovers[: limit - len(selected)])

    # per_bin is at least 1, so when limit < number of bins the quotas overshoot.
    return selected[: max(limit, 0)]


def _download_tile(session, url, filepath, tile_size):
    """Fetch ``url``, center-crop it to a square, resize, and save as JPEG.

    Raises whatever ``requests`` or PIL raise on network, HTTP, or decode errors.
    """
    resp = session.get(url, timeout=10)
    resp.raise_for_status()
    with Image.open(io.BytesIO(resp.content)) as src:
        img = src.convert("RGB")

    w, h = img.size
    if w != h:
        side = min(w, h)
        left = (w - side) // 2
        top = (h - side) // 2
        img = img.crop((left, top, left + side, top + side))

    img = img.resize((tile_size, tile_size), Image.Resampling.LANCZOS)
    img.save(filepath, "JPEG", quality=85, optimize=True)


def create_balanced_tiles(
    limit=500,
    tile_size=64,
    bins=12,
    *,
    out_dir="data/tiles",
    seed: Optional[int] = None,
):
    """Create a hue-balanced set of square colored tiles from Unsplash-Lite-Palette.

    Images are grouped by dominant palette hue into ``bins`` equal hue ranges,
    an equal number is drawn from each range (topped up from leftovers when
    bins are short), and each chosen image is downloaded, center-cropped,
    resized to ``tile_size`` x ``tile_size`` and saved as
    ``<out_dir>/tile_NNNN.jpg``. Failed downloads are logged and skipped.

    Args:
        limit: Maximum number of images to select.
        tile_size: Side length of each saved tile, in pixels.
        bins: Number of equal-width hue bins; must be at least 1.
        out_dir: Directory the tiles are written to (created if missing).
        seed: Seed for a private RNG; when ``None`` the global ``random``
            module is used, so an earlier ``random.seed()`` call still applies.

    Returns:
        The number of tiles successfully saved.

    Raises:
        ValueError: If ``bins`` is less than 1.
    """
    if bins < 1:
        raise ValueError("bins must be at least 1")
    rng = random if seed is None else random.Random(seed)

    os.makedirs(out_dir, exist_ok=True)
    print("Loading Unsplash-Lite-Palette dataset...")
    ds = load_dataset("1aurent/unsplash-lite-palette", split="train")

    hue_bins = _group_by_hue(ds, bins)
    print(f"Dataset grouped into {bins} hue bins.")

    selected = _select_balanced(hue_bins, limit, rng)
    print(f"Selected {len(selected)} images across hue bins.")

    successful = 0
    # A shared session reuses HTTP connections across the many small downloads.
    with requests.Session() as session:
        for i, url in enumerate(selected):
            filename = f"tile_{i:04d}.jpg"
            try:
                _download_tile(session, url, os.path.join(out_dir, filename), tile_size)
            except Exception as e:  # best-effort per image: log and continue
                print(f"✗ Failed {i} ({url}): {e}")
                continue
            successful += 1
            if i % 50 == 0:
                print(f"✓ {successful}/{i+1} saved ({filename})")

    print(f"\nDone! {successful} balanced color tiles saved in {out_dir}/")
    return successful


if __name__ == "__main__":
    create_balanced_tiles(limit=500, tile_size=64, bins=12)