|
|
""" |
|
|
Scroll Element Dataset Loader for Drizz Mobile App Testing |
|
|
|
|
|
Loads screenshots with bounding boxes and commands to identify scroll elements. |
|
|
Converts to GEPA-compatible format for prompt optimization. |
|
|
""" |
|
|
|
|
|
import base64
import logging
import random
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class ScrollDatasetLoader:
    """
    GENERIC dataset loader for image-based tasks.

    This is a LIBRARY class - NO hardcoded assumptions about:
    - What the task is (OCR, element detection, classification, etc.)
    - Input format (questions, commands, descriptions, etc.)
    - Output format (IDs, text, JSON, etc.)

    Users define their dataset in the test script and pass it here.

    Dataset format per item: (image_filename, input_text, expected_output)

    Example usage (ANY task):
        # Define YOUR dataset in YOUR test script
        my_dataset = [
            ("img1.png", "What is the main color?", "blue"),
            ("img2.png", "Count the objects", "5"),
            ("img3.png", "Describe the scene", "A cat on a sofa"),
        ]

        # Pass to loader
        loader = ScrollDatasetLoader(
            images_dir="images",
            dataset_config=my_dataset
        )
        data = loader.load_dataset()
    """

    # Accepted range for extracted element IDs.  Values outside this band are
    # treated as false positives (e.g. "999" is rejected; note that a bare
    # percentage such as "70" still passes - TODO confirm this is acceptable
    # for the target datasets).
    _MIN_ELEMENT_ID = 1
    _MAX_ELEMENT_ID = 100

    def __init__(
        self,
        images_dir: str = "images",
        dataset_config: Optional[List[Tuple[str, str, str]]] = None
    ):
        """
        Initialize dataset loader.

        Args:
            images_dir: Directory containing images
            dataset_config: List of (image_filename, input_text, expected_output) tuples.
                           REQUIRED - no hardcoded defaults to keep library generic.

        Raises:
            FileNotFoundError: If images_dir doesn't exist
            ValueError: If dataset_config is None
        """
        self.images_dir = Path(images_dir)

        if not self.images_dir.exists():
            raise FileNotFoundError(f"Images directory not found: {images_dir}")

        if dataset_config is None:
            raise ValueError(
                "dataset_config is required. This is a library class - define your "
                "dataset in the test script:\n"
                "  dataset = [('img1.png', 'your input', 'expected output'), ...]\n"
                "  loader = ScrollDatasetLoader(images_dir='...', dataset_config=dataset)"
            )

        self.dataset_config = dataset_config

    def load_dataset(self) -> List[Dict[str, Any]]:
        """
        Load complete dataset with images.

        Items whose image file is missing or cannot be encoded are skipped
        with a warning rather than aborting the whole load.

        Returns:
            List of dataset items in GEPA format:
            [
                {
                    "input": "Command: Scroll down by 70%",
                    "output": "3",
                    "image_base64": "<base64_encoded_image>",  # TOP LEVEL
                    "metadata": {
                        "image_path": "images/5.png",
                        "input_text": "Command: Scroll down by 70%",
                        "expected_output": "3",
                        "image_filename": "5.png",
                        "element_id": 3  # Extracted integer (None if extraction fails)
                    }
                },
                ...
            ]

        Raises:
            ValueError: If no item in dataset_config yields a usable image.
        """
        dataset = []

        for image_filename, input_text, expected_output in self.dataset_config:
            image_path = self.images_dir / image_filename

            # Skip (do not fail) on missing images so one bad entry
            # doesn't invalidate the rest of the dataset.
            if not image_path.exists():
                logger.warning("Image not found: %s", image_path)
                continue

            try:
                image_base64 = self._encode_image(image_path)
            except Exception as e:
                logger.warning("Error encoding %s: %s", image_filename, e)
                continue

            # element_id is best-effort metadata: None is a valid outcome.
            element_id = self._extract_element_id(expected_output)
            if element_id is None:
                logger.warning(
                    "Could not extract element_id from '%s' in %s",
                    expected_output, image_filename
                )

            dataset.append({
                "input": input_text,
                "output": expected_output,
                "image_base64": image_base64,
                "metadata": {
                    "image_path": str(image_path),
                    "input_text": input_text,
                    "expected_output": expected_output,
                    "image_filename": image_filename,
                    "element_id": element_id
                }
            })

        if not dataset:
            raise ValueError("No valid images found in dataset")

        logger.info("Loaded %d scroll element detection samples", len(dataset))
        return dataset

    def _extract_element_id(self, expected_output: str) -> Optional[int]:
        """
        Extract element ID from expected output string.

        Handles multiple formats:
        - "Element: 4"
        - "Element 4"
        - "4" (standalone)
        - "Element: 4, Description: ..." (full reasoning)

        Args:
            expected_output: Full expected output string with reasoning

        Returns:
            Element ID as integer, or None if not found
        """
        if not expected_output:
            return None

        # Prefer an explicit "Element ..." mention over a bare number.
        patterns = [
            r'element[:\s]+(\d+)',   # "Element: 4" / "Element 4"
            r'\belement\s+(\d+)\b',  # word-boundary variant kept as safety net
        ]
        for pattern in patterns:
            match = re.search(pattern, expected_output, re.IGNORECASE)
            if match:
                try:
                    element_id = int(match.group(1))
                except (ValueError, IndexError):
                    continue
                if self._MIN_ELEMENT_ID <= element_id <= self._MAX_ELEMENT_ID:
                    return element_id

        # Fallback: any standalone 1-3 digit number within the valid range.
        number_match = re.search(r'\b(\d{1,3})\b', expected_output)
        if number_match:
            try:
                element_id = int(number_match.group(1))
            except ValueError:
                pass
            else:
                if self._MIN_ELEMENT_ID <= element_id <= self._MAX_ELEMENT_ID:
                    return element_id

        return None

    def _encode_image(self, image_path: Path) -> str:
        """
        Encode image to base64 string.

        Args:
            image_path: Path to image file

        Returns:
            Base64 encoded image string
        """
        return base64.b64encode(image_path.read_bytes()).decode('utf-8')

    def split_dataset(
        self,
        dataset: List[Dict[str, Any]],
        train_size: int = 4,
        val_size: int = 1,
        test_size: int = 1,
        shuffle: bool = True,
        seed: Optional[int] = None
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Split dataset into train, validation, and test sets.

        Shuffling ensures different image distribution across splits,
        preventing hard images from always landing in the validation set.
        A local RNG is used so the module never perturbs the global
        ``random`` state of the calling program.

        Args:
            dataset: Complete dataset
            train_size: Number of samples for training (default: 4)
            val_size: Number of samples for validation (default: 1)
            test_size: Number of samples for test (default: 1)
            shuffle: Whether to shuffle dataset before splitting (default: True)
            seed: Random seed for reproducible shuffling (default: None = random)

        Returns:
            Tuple of (train_set, val_set, test_set)
        """
        n = len(dataset)

        # If the requested split is larger than the dataset, shrink it
        # proportionally; test_size takes the remainder so sizes sum to n.
        total_size = train_size + val_size + test_size
        if total_size > n:
            logger.warning(
                "Requested split (%d) exceeds dataset size (%d). "
                "Adjusting split proportionally...", total_size, n
            )
            ratio = n / total_size
            train_size = int(train_size * ratio)
            val_size = int(val_size * ratio)
            test_size = n - train_size - val_size

        dataset_copy = dataset.copy()
        if shuffle:
            # Dedicated RNG instance: same shuffle sequence as the global
            # Mersenne Twister for a given seed, but without the side effect
            # of reseeding random's module-level state.
            rng = random.Random(seed)
            if seed is not None:
                logger.debug("Shuffling dataset with seed=%s for reproducible splits", seed)
            else:
                logger.debug("Shuffling dataset randomly (no seed)")
            rng.shuffle(dataset_copy)
        else:
            logger.warning("Not shuffling dataset - using original order")

        train_set = dataset_copy[:train_size]
        val_set = dataset_copy[train_size:train_size + val_size]
        test_set = dataset_copy[train_size + val_size:train_size + val_size + test_size]

        logger.info(
            "Dataset split: %d train, %d val, %d test",
            len(train_set), len(val_set), len(test_set)
        )

        # Debug visibility into which images landed in which split
        # (library code logs instead of printing to stdout).
        if shuffle:
            train_images = [item['metadata'].get('image_filename', 'N/A') for item in train_set]
            val_images = [item['metadata'].get('image_filename', 'N/A') for item in val_set]
            test_images = [item['metadata'].get('image_filename', 'N/A') for item in test_set]
            logger.debug(" Train images: %s%s", train_images[:5], '...' if len(train_images) > 5 else '')
            logger.debug(" Val images: %s", val_images)
            logger.debug(" Test images: %s%s", test_images[:5], '...' if len(test_images) > 5 else '')

        return train_set, val_set, test_set
|
|
|
|
|
|
|
|
def load_scroll_dataset(
    images_dir: str = "images",
    dataset_config: Optional[List[Tuple[str, str, str]]] = None,
    split: bool = True
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Convenience function to load image-based dataset (GENERIC).

    Args:
        images_dir: Directory containing images
        dataset_config: List of (image_filename, input_text, expected_output) tuples.
                        Required by the underlying loader; passing None raises ValueError.
        split: Whether to split into train/val/test

    Returns:
        If split=True: (train_set, val_set, test_set)
        If split=False: (full_dataset, [], [])

    Raises:
        FileNotFoundError: If images_dir doesn't exist
        ValueError: If dataset_config is None or yields no usable images

    Example (works for ANY task):
        dataset_config = [
            ("img1.png", "What color is the sky?", "blue"),
            ("img2.png", "Count the dogs", "2"),
        ]
        train, val, test = load_scroll_dataset(
            images_dir="images",
            dataset_config=dataset_config
        )
    """
    loader = ScrollDatasetLoader(images_dir, dataset_config=dataset_config)
    dataset = loader.load_dataset()

    if split:
        return loader.split_dataset(dataset)
    # Unsplit mode keeps the tuple shape so callers can unpack uniformly.
    return dataset, [], []
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Usage banner only: this module is a library, so the demo just shows
    # how a caller is expected to define and pass a dataset.
    usage_lines = [
        "🚀 Testing Scroll Dataset Loader...",
        "⚠️ NOTE: This is a library class. Define your dataset in your test script.",
        "\nExample:",
        " dataset_config = [",
        " ('image1.png', 'Scroll down by 50%', '3'),",
        " ('image2.png', 'Swipe left', '4'),",
        " ]",
        " train, val, test = load_scroll_dataset(",
        " images_dir='images',",
        " dataset_config=dataset_config",
        " )",
    ]
    for usage_line in usage_lines:
        print(usage_line)
|
|
|
|
|
|