# NOTE(review): The three lines below are Hugging Face Spaces page residue
# (avatar alt-text, commit message, commit hash) accidentally captured into
# the source file; commented out so the module parses.
# Suhasdev's picture
# Deploy Universal Prompt Optimizer to HF Spaces (clean)
# cacd4d0
"""
Scroll Element Dataset Loader for Drizz Mobile App Testing
Loads screenshots with bounding boxes and commands to identify scroll elements.
Converts to GEPA-compatible format for prompt optimization.
"""
import base64
import logging
import random
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
class ScrollDatasetLoader:
    """
    GENERIC dataset loader for image-based tasks.

    This is a LIBRARY class - NO hardcoded assumptions about:
    - What the task is (OCR, element detection, classification, etc.)
    - Input format (questions, commands, descriptions, etc.)
    - Output format (IDs, text, JSON, etc.)

    Users define their dataset in the test script and pass it here.
    Dataset format per item: (image_filename, input_text, expected_output)

    Example usage (ANY task)::

        my_dataset = [
            ("img1.png", "What is the main color?", "blue"),
            ("img2.png", "Count the objects", "5"),
            ("img3.png", "Describe the scene", "A cat on a sofa"),
        ]
        loader = ScrollDatasetLoader(
            images_dir="images",
            dataset_config=my_dataset,
        )
        data = loader.load_dataset()
    """

    # IDs outside this inclusive range are treated as extraction noise
    # (e.g. pixel counts or percentages caught by the fallback regex).
    _MIN_ELEMENT_ID = 1
    _MAX_ELEMENT_ID = 100

    def __init__(
        self,
        images_dir: str = "images",
        dataset_config: Optional[List[Tuple[str, str, str]]] = None
    ):
        """
        Initialize dataset loader.

        Args:
            images_dir: Directory containing images.
            dataset_config: List of (image_filename, input_text, expected_output)
                tuples. REQUIRED - no hardcoded defaults to keep library generic.

        Raises:
            FileNotFoundError: If images_dir doesn't exist.
            ValueError: If dataset_config is None.
        """
        self.images_dir = Path(images_dir)
        if not self.images_dir.exists():
            raise FileNotFoundError(f"Images directory not found: {images_dir}")
        if dataset_config is None:
            raise ValueError(
                "dataset_config is required. This is a library class - define your "
                "dataset in the test script:\n"
                "  dataset = [('img1.png', 'your input', 'expected output'), ...]\n"
                "  loader = ScrollDatasetLoader(images_dir='...', dataset_config=dataset)"
            )
        self.dataset_config = dataset_config

    def load_dataset(self) -> List[Dict[str, Any]]:
        """
        Load complete dataset with images.

        Items whose image is missing or unreadable are skipped with a warning
        rather than aborting the whole load (best-effort semantics).

        Returns:
            List of dataset items in GEPA format::

                {
                    "input": "Command: Scroll down by 70%",
                    "output": "3",
                    "image_base64": "<base64_encoded_image>",  # TOP LEVEL
                    "metadata": {
                        "image_path": "images/5.png",
                        "input_text": "Command: Scroll down by 70%",
                        "expected_output": "3",
                        "image_filename": "5.png",
                        "element_id": 3  # int, or None if extraction fails
                    }
                }

        Raises:
            ValueError: If no item yielded a loadable image.
        """
        dataset: List[Dict[str, Any]] = []
        # Generic variable names - no assumptions about data type.
        for image_filename, input_text, expected_output in self.dataset_config:
            image_path = self.images_dir / image_filename
            if not image_path.exists():
                logger.warning("Image not found: %s", image_path)
                continue
            # Broad catch is deliberate: a single unreadable image should not
            # abort the whole dataset load.
            try:
                image_base64 = self._encode_image(image_path)
            except Exception as e:
                logger.warning("Error encoding %s: %s", image_filename, e)
                continue
            # Extract element_id from expected_output for robust evaluation;
            # a None here is logged but the item is still kept.
            element_id = self._extract_element_id(expected_output)
            if element_id is None:
                logger.warning(
                    "Could not extract element_id from '%s' in %s",
                    expected_output, image_filename,
                )
            # Item is COMPLETELY GENERIC: image + input text + expected output.
            # image_base64 sits at TOP LEVEL so UniversalConverter can find it.
            dataset.append({
                "input": input_text,
                "output": expected_output,
                "image_base64": image_base64,
                "metadata": {
                    "image_path": str(image_path),
                    "input_text": input_text,
                    "expected_output": expected_output,
                    "image_filename": image_filename,
                    "element_id": element_id,
                },
            })
        if not dataset:
            raise ValueError("No valid images found in dataset")
        logger.info("Loaded %d scroll element detection samples", len(dataset))
        return dataset

    def _extract_element_id(self, expected_output: str) -> Optional[int]:
        """
        Extract element ID from expected output string.

        Handles multiple formats:
        - "Element: 4"
        - "Element 4"
        - "4" (standalone)
        - "Element: 4, Description: ..." (full reasoning)

        Args:
            expected_output: Full expected output string with reasoning.

        Returns:
            Element ID as integer, or None if not found.
        """
        if not expected_output:
            return None
        # Preferred: an explicit "Element: X" / "Element X" marker.
        match = re.search(r'element[:\s]+(\d+)', expected_output, re.IGNORECASE)
        if match:
            element_id = int(match.group(1))
            if self._MIN_ELEMENT_ID <= element_id <= self._MAX_ELEMENT_ID:
                return element_id
        # Fallback: first standalone 1-3 digit number in a plausible range.
        match = re.search(r'\b(\d{1,3})\b', expected_output)
        if match:
            element_id = int(match.group(1))
            if self._MIN_ELEMENT_ID <= element_id <= self._MAX_ELEMENT_ID:
                return element_id
        return None

    def _encode_image(self, image_path: Path) -> str:
        """
        Encode image to base64 string.

        Args:
            image_path: Path to image file.

        Returns:
            Base64 encoded image string.
        """
        return base64.b64encode(image_path.read_bytes()).decode('utf-8')

    def split_dataset(
        self,
        dataset: List[Dict[str, Any]],
        train_size: int = 4,
        val_size: int = 1,
        test_size: int = 1,
        shuffle: bool = True,
        seed: Optional[int] = None
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Split dataset into train, validation, and test sets.

        Shuffling ensures a different image distribution across splits,
        preventing hard images from always landing in the validation set.

        Args:
            dataset: Complete dataset.
            train_size: Number of samples for training (default: 4).
            val_size: Number of samples for validation (default: 1).
            test_size: Number of samples for test (default: 1).
            shuffle: Whether to shuffle dataset before splitting (default: True).
            seed: Random seed for reproducible shuffling (default: None = random).

        Returns:
            Tuple of (train_set, val_set, test_set).
        """
        n = len(dataset)
        total_size = train_size + val_size + test_size
        if total_size > n:
            logger.warning(
                "Requested split (%d) exceeds dataset size (%d). "
                "Adjusting split proportionally...", total_size, n,
            )
            ratio = n / total_size
            train_size = int(train_size * ratio)
            val_size = int(val_size * ratio)
            # Remainder goes to test so the three sizes always sum to n.
            test_size = n - train_size - val_size
        # Work on a copy so the caller's list is never reordered.
        items = list(dataset)
        if shuffle:
            # Use a LOCAL Random instance: seeding the module-global RNG
            # (random.seed) would leak into unrelated code in the process.
            random.Random(seed).shuffle(items)
            if seed is not None:
                logger.debug("Shuffling dataset with seed=%s for reproducible splits", seed)
            else:
                logger.debug("Shuffling dataset randomly (no seed)")
        else:
            logger.warning("Not shuffling dataset - using original order")
        train_set = items[:train_size]
        val_set = items[train_size:train_size + val_size]
        test_set = items[train_size + val_size:train_size + val_size + test_size]
        logger.info(
            "Dataset split: %d train, %d val, %d test",
            len(train_set), len(val_set), len(test_set),
        )
        # Log which images landed in each split for debugging (via logger,
        # not print - this is library code).
        if shuffle:
            train_images = [item['metadata'].get('image_filename', 'N/A') for item in train_set]
            val_images = [item['metadata'].get('image_filename', 'N/A') for item in val_set]
            test_images = [item['metadata'].get('image_filename', 'N/A') for item in test_set]
            logger.debug("Train images: %s%s", train_images[:5], '...' if len(train_images) > 5 else '')
            logger.debug("Val images: %s", val_images)
            logger.debug("Test images: %s%s", test_images[:5], '...' if len(test_images) > 5 else '')
        return train_set, val_set, test_set
def load_scroll_dataset(
    images_dir: str = "images",
    dataset_config: Optional[List[Tuple[str, str, str]]] = None,
    split: bool = True
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Convenience function to load image-based dataset (GENERIC).

    Args:
        images_dir: Directory containing images.
        dataset_config: List of (image_filename, input_text, expected_output)
            tuples. Required by the underlying loader; passing None raises
            ValueError.
        split: Whether to split into train/val/test.

    Returns:
        If split=True: (train_set, val_set, test_set)
        If split=False: (full_dataset, [], [])

    Example (works for ANY task)::

        dataset_config = [
            ("img1.png", "What color is the sky?", "blue"),
            ("img2.png", "Count the dogs", "2"),
        ]
        train, val, test = load_scroll_dataset(
            images_dir="images",
            dataset_config=dataset_config,
        )
    """
    # Annotation fixed to Optional[...]: the default is None (PEP 484 bans
    # implicit Optional); the loader itself enforces that it is provided.
    loader = ScrollDatasetLoader(images_dir, dataset_config=dataset_config)
    dataset = loader.load_dataset()
    if split:
        return loader.split_dataset(dataset)
    # No split requested: return everything as "train", empty val/test.
    return dataset, [], []
# Example usage (for testing the library loader itself)
if __name__ == "__main__":
    # This module is a library: when run directly, just show how a test
    # script should define and pass its own dataset.
    usage_lines = (
        "🚀 Testing Scroll Dataset Loader...",
        "⚠️ NOTE: This is a library class. Define your dataset in your test script.",
        "\nExample:",
        " dataset_config = [",
        " ('image1.png', 'Scroll down by 50%', '3'),",
        " ('image2.png', 'Swipe left', '4'),",
        " ]",
        " train, val, test = load_scroll_dataset(",
        " images_dir='images',",
        " dataset_config=dataset_config",
        " )",
    )
    for usage_line in usage_lines:
        print(usage_line)