File size: 12,820 Bytes
cacd4d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
"""
Scroll Element Dataset Loader for Drizz Mobile App Testing

Loads screenshots with bounding boxes and commands to identify scroll elements.
Converts to GEPA-compatible format for prompt optimization.
"""

import base64
import logging
import random
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)


class ScrollDatasetLoader:
    """
    Generic dataset loader for image-based tasks.

    This is a LIBRARY class - NO hardcoded assumptions about:
    - What the task is (OCR, element detection, classification, etc.)
    - Input format (questions, commands, descriptions, etc.)
    - Output format (IDs, text, JSON, etc.)

    Users define their dataset in the test script and pass it here.

    Dataset format per item: (image_filename, input_text, expected_output)

    Example usage (ANY task):
        my_dataset = [
            ("img1.png", "What is the main color?", "blue"),
            ("img2.png", "Count the objects", "5"),
            ("img3.png", "Describe the scene", "A cat on a sofa"),
        ]
        loader = ScrollDatasetLoader(images_dir="images", dataset_config=my_dataset)
        data = loader.load_dataset()
    """

    # Inclusive range of plausible UI element IDs used to validate extraction.
    _MIN_ELEMENT_ID = 1
    _MAX_ELEMENT_ID = 100

    # "Element: 4" or "Element 4" (case-insensitive). Compiled once at class
    # creation instead of on every _extract_element_id call.
    _ELEMENT_PATTERN = re.compile(r'element[:\s]+(\d+)', re.IGNORECASE)
    # Fallback: first standalone 1-3 digit number anywhere in the text.
    _NUMBER_PATTERN = re.compile(r'\b(\d{1,3})\b')

    def __init__(
        self,
        images_dir: str = "images",
        dataset_config: Optional[List[Tuple[str, str, str]]] = None
    ):
        """
        Initialize dataset loader.

        Args:
            images_dir: Directory containing images
            dataset_config: List of (image_filename, input_text, expected_output) tuples.
                           REQUIRED - no hardcoded defaults to keep library generic.

        Raises:
            FileNotFoundError: If images_dir doesn't exist
            ValueError: If dataset_config is None
        """
        self.images_dir = Path(images_dir)

        if not self.images_dir.exists():
            raise FileNotFoundError(f"Images directory not found: {images_dir}")

        if dataset_config is None:
            raise ValueError(
                "dataset_config is required. This is a library class - define your "
                "dataset in the test script:\n"
                "  dataset = [('img1.png', 'your input', 'expected output'), ...]\n"
                "  loader = ScrollDatasetLoader(images_dir='...', dataset_config=dataset)"
            )

        self.dataset_config = dataset_config

    def load_dataset(self) -> List[Dict[str, Any]]:
        """
        Load complete dataset with images.

        Items whose image is missing or unreadable are skipped with a warning
        rather than aborting the whole load.

        Returns:
            List of dataset items in GEPA format:
            [
                {
                    "input": "Command: Scroll down by 70%",
                    "output": "3",
                    "image_base64": "<base64_encoded_image>",  # TOP LEVEL
                    "metadata": {
                        "image_path": "images/5.png",
                        "input_text": "Command: Scroll down by 70%",
                        "expected_output": "3",
                        "image_filename": "5.png",
                        "element_id": 3  # Extracted integer (None if extraction fails)
                    }
                },
                ...
            ]

        Raises:
            ValueError: If no item yields a valid, readable image.
        """
        dataset: List[Dict[str, Any]] = []

        # Generic variable names - no assumptions about data type
        for image_filename, input_text, expected_output in self.dataset_config:
            image_path = self.images_dir / image_filename

            # Validate image exists
            if not image_path.exists():
                logger.warning("Image not found: %s", image_path)
                continue

            # Read and encode image
            try:
                image_base64 = self._encode_image(image_path)
            except Exception as e:
                logger.warning("Error encoding %s: %s", image_filename, e)
                continue

            # Extract element_id from expected_output for robust evaluation.
            element_id = self._extract_element_id(expected_output)
            if element_id is None:
                logger.warning(
                    "Could not extract element_id from '%s' in %s",
                    expected_output, image_filename
                )

            # Create dataset item - completely generic: image + input text +
            # expected output text. The library doesn't know what the task is.
            # IMPORTANT: image_base64 must be TOP LEVEL for UniversalConverter.
            dataset.append({
                "input": input_text,  # Generic input text (ANY format)
                "output": expected_output,  # Generic expected output (ANY format)
                "image_base64": image_base64,  # TOP LEVEL for converter
                "metadata": {
                    "image_path": str(image_path),
                    "input_text": input_text,
                    "expected_output": expected_output,
                    "image_filename": image_filename,
                    "element_id": element_id  # Extracted element ID (int or None)
                }
            })

        if not dataset:
            raise ValueError("No valid images found in dataset")

        logger.info("Loaded %d scroll element detection samples", len(dataset))
        return dataset

    def _extract_element_id(self, expected_output: str) -> Optional[int]:
        """
        Extract element ID from expected output string.

        Handles multiple formats:
        - "Element: 4"
        - "Element 4"
        - "4" (standalone)
        - "Element: 4, Description: ..." (full reasoning)

        Args:
            expected_output: Full expected output string with reasoning

        Returns:
            Element ID as integer, or None if not found or out of the
            plausible range [_MIN_ELEMENT_ID, _MAX_ELEMENT_ID].
        """
        if not expected_output:
            return None

        # Preferred: an explicit "Element: X" / "Element X" marker.
        match = self._ELEMENT_PATTERN.search(expected_output)
        if match:
            element_id = int(match.group(1))
            if self._MIN_ELEMENT_ID <= element_id <= self._MAX_ELEMENT_ID:
                return element_id

        # Fallback: first standalone number, accepted only if it looks like a
        # reasonable UI element ID.
        match = self._NUMBER_PATTERN.search(expected_output)
        if match:
            element_id = int(match.group(1))
            if self._MIN_ELEMENT_ID <= element_id <= self._MAX_ELEMENT_ID:
                return element_id

        return None

    def _encode_image(self, image_path: Path) -> str:
        """
        Encode image to base64 string.

        Args:
            image_path: Path to image file

        Returns:
            Base64 encoded image string
        """
        return base64.b64encode(image_path.read_bytes()).decode('utf-8')

    def split_dataset(
        self,
        dataset: List[Dict[str, Any]],
        train_size: int = 4,
        val_size: int = 1,
        test_size: int = 1,
        shuffle: bool = True,
        seed: Optional[int] = None
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Split dataset into train, validation, and test sets.

        Shuffling ensures a different image distribution across splits,
        preventing hard images from always landing in the validation set.
        A private ``random.Random`` instance is used so seeding never
        mutates the caller's global random state.

        Args:
            dataset: Complete dataset
            train_size: Number of samples for training (default: 4)
            val_size: Number of samples for validation (default: 1)
            test_size: Number of samples for test (default: 1)
            shuffle: Whether to shuffle dataset before splitting (default: True)
            seed: Random seed for reproducible shuffling (default: None = random)

        Returns:
            Tuple of (train_set, val_set, test_set)
        """
        n = len(dataset)

        # Validate split sizes; shrink proportionally if they exceed the data.
        total_size = train_size + val_size + test_size
        if total_size > n:
            logger.warning(
                "Requested split (%d) exceeds dataset size (%d). "
                "Adjusting split proportionally...", total_size, n
            )
            ratio = n / total_size
            train_size = int(train_size * ratio)
            val_size = int(val_size * ratio)
            test_size = n - train_size - val_size

        # Work on a copy so the caller's list order is never modified.
        dataset_copy = dataset.copy()
        if shuffle:
            # Local RNG: random.Random(None) seeds from OS entropy, so the
            # global `random` module state is left untouched either way.
            rng = random.Random(seed)
            if seed is not None:
                logger.debug("Shuffling dataset with seed=%s for reproducible splits", seed)
            else:
                logger.debug("Shuffling dataset randomly (no seed)")
            rng.shuffle(dataset_copy)
        else:
            logger.warning("Not shuffling dataset - using original order")

        # Split (possibly shuffled) dataset
        train_set = dataset_copy[:train_size]
        val_set = dataset_copy[train_size:train_size + val_size]
        test_set = dataset_copy[train_size + val_size:train_size + val_size + test_size]

        logger.info(
            "Dataset split: %d train, %d val, %d test",
            len(train_set), len(val_set), len(test_set)
        )

        # Log which images landed in each split for debugging (via logger, not
        # print: library code should never write directly to stdout).
        if shuffle:
            train_images = [item['metadata'].get('image_filename', 'N/A') for item in train_set]
            val_images = [item['metadata'].get('image_filename', 'N/A') for item in val_set]
            test_images = [item['metadata'].get('image_filename', 'N/A') for item in test_set]
            logger.info("Train images: %s%s", train_images[:5], '...' if len(train_images) > 5 else '')
            logger.info("Val images:   %s", val_images)
            logger.info("Test images:  %s%s", test_images[:5], '...' if len(test_images) > 5 else '')

        return train_set, val_set, test_set


def load_scroll_dataset(
    images_dir: str = "images",
    dataset_config: Optional[List[Tuple[str, str, str]]] = None,
    split: bool = True
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Convenience function to load an image-based dataset (GENERIC).

    Args:
        images_dir: Directory containing images
        dataset_config: List of (image_filename, input_text, expected_output)
                        tuples. Required; a ValueError is raised by the loader
                        if omitted.
        split: Whether to split into train/val/test

    Returns:
        If split=True: (train_set, val_set, test_set)
        If split=False: (full_dataset, [], [])

    Raises:
        FileNotFoundError: If images_dir doesn't exist.
        ValueError: If dataset_config is None or no valid images are found.

    Example (works for ANY task):
        dataset_config = [
            ("img1.png", "What color is the sky?", "blue"),
            ("img2.png", "Count the dogs", "2"),
        ]
        train, val, test = load_scroll_dataset(
            images_dir="images",
            dataset_config=dataset_config
        )
    """
    loader = ScrollDatasetLoader(images_dir, dataset_config=dataset_config)
    dataset = loader.load_dataset()

    if split:
        return loader.split_dataset(dataset)
    # No split requested: everything goes in the first slot.
    return dataset, [], []


# Example usage (for testing the library loader itself)
if __name__ == "__main__":
    # Demo/self-test entry point: the loader is a generic library, so the
    # dataset definition always lives in the calling script.
    _usage_lines = [
        "🚀 Testing Scroll Dataset Loader...",
        "⚠️  NOTE: This is a library class. Define your dataset in your test script.",
        "\nExample:",
        "  dataset_config = [",
        "      ('image1.png', 'Scroll down by 50%', '3'),",
        "      ('image2.png', 'Swipe left', '4'),",
        "  ]",
        "  train, val, test = load_scroll_dataset(",
        "      images_dir='images',",
        "      dataset_config=dataset_config",
        "  )",
    ]
    print("\n".join(_usage_lines))