File size: 5,331 Bytes
188709e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import json
import logging
from typing import Iterator, Dict, Any

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("UIControlDataset")

try:
    import tensorflow as tf
    from huggingface_hub import snapshot_download
except ImportError:
    logger.warning("TensorFlow or HuggingFace Hub is not installed. To run this dataset parser, install them using: pip install tensorflow huggingface_hub")

class UIControlDatasetManager:
    """
    Manager to download, stream, and parse the Google Android UI Control Dataset from Hugging Face.
    Used to train and fine-tune Aura Assist's screen automation and context parsing engines.
    
    Paper Reference: "On the Effects of Data Scale on UI Control Agents" (Wei Li et al., 2024)
    """
    def __init__(self, cache_dir: str = "./memory/datasets/android_control"):
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)

    def download_dataset(self, repo_id: str = "google/android_control") -> str:
        """
        Downloads the dataset files securely from Hugging Face Hub using snapshot_download.
        Returns the absolute path to the downloaded directory containing TFRecords.
        """
        logger.info(f"Downloading Android UI Control Dataset from HF repo: '{repo_id}'...")
        try:
            download_path = snapshot_download(
                repo_id=repo_id,
                repo_type="dataset",
                cache_dir=self.cache_dir,
                allow_patterns=["android_control*"]
            )
            logger.info(f"Dataset downloaded successfully to: {download_path}")
            return download_path
        except Exception as e:
            logger.error(f"Failed downloading from Hugging Face Hub: {e}")
            # Fallback to returning local directory path
            return self.cache_dir

    def _parse_tfrecord_example(self, serialized_example: bytes) -> Dict[str, Any]:
        """
        Parses a single serialized TFRecord example from the Android Control Dataset structure.
        Extracts tasks, screen dimensions, visual layout view hierarchies, bounding boxes, and action targets.
        """
        # Description of features in the TFRecord schema for Android UI Control Agents
        feature_description = {
            'task_instruction': tf.io.FixedLenFeature([], tf.string),
            'screenshot_png': tf.io.FixedLenFeature([], tf.string, default_value=''),
            'ui_hierarchy_json': tf.io.FixedLenFeature([], tf.string, default_value=''),
            'action_type': tf.io.FixedLenFeature([], tf.string),
            'action_touch_x': tf.io.FixedLenFeature([], tf.float32, default_value=0.0),
            'action_touch_y': tf.io.FixedLenFeature([], tf.float32, default_value=0.0),
            'app_package': tf.io.FixedLenFeature([], tf.string, default_value=''),
        }
        
        parsed = tf.io.parse_single_example(serialized_example, feature_description)
        
        # Decode strings safely
        task_instruction = parsed['task_instruction'].numpy().decode('utf-8')
        ui_hierarchy = parsed['ui_hierarchy_json'].numpy().decode('utf-8')
        action_type = parsed['action_type'].numpy().decode('utf-8')
        app_package = parsed['app_package'].numpy().decode('utf-8')
        
        # Try parsing JSON hierarchy
        hierarchy_dict = {}
        if ui_hierarchy:
            try:
                hierarchy_dict = json.loads(ui_hierarchy)
            except json.JSONDecodeError:
                hierarchy_dict = {"raw": ui_hierarchy}

        return {
            "task": task_instruction,
            "app_package": app_package,
            "action": {
                "type": action_type,
                "x": float(parsed['action_touch_x'].numpy()),
                "y": float(parsed['action_touch_y'].numpy()),
            },
            "ui_hierarchy": hierarchy_dict
        }

    def stream_records(self, dataset_dir: str) -> Iterator[Dict[str, Any]]:
        """
        Locates the GZIP compressed TFRecord files and streams parsed structured screen-action pairs.
        Can be fed directly into a fine-tuning dataset creator or Aura Assist's testing layers.
        """
        # Resolve path
        glob_pattern = os.path.join(dataset_dir, "android_control*")
        filenames = tf.io.gfile.glob(glob_pattern)
        
        if not filenames:
            logger.warning(f"No compressed TFRecord files matching 'android_control*' found in {dataset_dir}")
            return
        
        logger.info(f"Streaming {len(filenames)} GZIP compressed TFRecord segments...")
        
        # Create standard TFRecord dataset
        raw_dataset = tf.data.TFRecordDataset(filenames, compression_type='GZIP')
        
        for raw_record in raw_dataset:
            try:
                # Process example using eager execution
                yield self._parse_tfrecord_example(raw_record.numpy())
            except Exception as e:
                logger.error(f"Skipping corrupt record example: {e}")
                continue

# Quick test execution scaffold
if __name__ == "__main__":
    manager = UIControlDatasetManager()
    logger.info("Initializing UI Control Dataset Manager...")
    print("Dataset Manager Ready. Ready to train AURA on app navigation paths!")