---
language: en
license: apache-2.0
base_model: google/gemma-4-E4B-it
tags:
  - android
  - ui-automation
  - accessibility
  - lora
  - peft
---

# Cerebellum — Android UI Action Predictor

LoRA adapter on top of `google/gemma-4-E4B-it` that predicts the next Android UI action given a screenshot and accessibility tree.

**Architecture:** An upstream LLM (or orchestrating agent) issues high-level intent. Cerebellum executes that intent locally, grounding it to a specific UI element and action without screenshot round-trips to a remote model.

---

## What It Does

Given a task goal, the current screen (screenshot + accessibility tree), and optional action history, the model outputs a single compact action code indicating what to do next.

---

## Input Format

The model uses a chat-style prompt (Gemma4 format). The user turn is structured as:

```
Task: {goal}

Step 1 (past): <|image|> -> {action_text}
Step 2 (past): <|image|> -> {action_text}
...
Current screen: <|image|>
{compressed_accessibility_tree}
[n zone]=tap-target(top-to-bottom left-to-right) zone=tl/tc/tr/ml/mc/mr/bl/bc/br  ed=text-input sr=scrollable fc=focused(use 'K your_text' to type here)
Actions: T{n}=tap element n, P{n}=long-press element n, K {text}=type text(space required), U/D/L/R=scroll(single token), B=back, H=home, W=wait, F=done, I=impossible
Next action:
```

**Inputs:**
- `goal` — natural language task description (e.g. "Open the settings app and enable dark mode")
- `history` — up to 4 past (screenshot, action) pairs; can be empty
- `current screenshot` — PIL image of the current screen, resized to 896px on the long edge
- `compressed_accessibility_tree` — compact text representation of the UI element tree (see below)
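
The template can be assembled with plain string formatting. The sketch below is illustrative, not the project's code: `build_user_turn`, `resized_size` and the constants are hypothetical names. It assumes `history` holds past action texts oldest-first, that only the last four are kept and renumbered from 1, and that `<|image|>` stays a literal marker in the text. The card does not say whether screenshots smaller than 896px are upscaled; this helper scales either way.

```python
from __future__ import annotations

from typing import Sequence

IMAGE = "<|image|>"  # literal marker, as it appears in the template above
MAX_HISTORY = 4      # the card allows up to 4 past (screenshot, action) pairs

LEGEND = (
    "[n zone]=tap-target(top-to-bottom left-to-right) "
    "zone=tl/tc/tr/ml/mc/mr/bl/bc/br  ed=text-input sr=scrollable "
    "fc=focused(use 'K your_text' to type here)"
)
ACTIONS = (
    "Actions: T{n}=tap element n, P{n}=long-press element n, "
    "K {text}=type text(space required), U/D/L/R=scroll(single token), "
    "B=back, H=home, W=wait, F=done, I=impossible"
)


def resized_size(width: int, height: int, long_edge: int = 896) -> tuple[int, int]:
    """Target size with the long edge at `long_edge` px, aspect ratio preserved."""
    scale = long_edge / max(width, height)
    return round(width * scale), round(height * scale)


def build_user_turn(goal: str, history: Sequence[str], tree: str) -> str:
    """Assemble the user-turn text; `history` holds past action texts, oldest first."""
    lines = [f"Task: {goal}", ""]
    # Keep only the most recent steps and number them from 1 (assumption).
    for i, action_text in enumerate(list(history)[-MAX_HISTORY:], start=1):
        lines.append(f"Step {i} (past): {IMAGE} -> {action_text}")
    lines += [f"Current screen: {IMAGE}", tree, LEGEND, ACTIONS, "Next action:"]
    return "\n".join(lines)
```

The returned string would go into the user turn of the chat template. The screenshots, resized as above, would be supplied to the processor in the same order as the `<|image|>` markers.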

### Accessibility Tree Format

Each interactive element is one line:

```
[0 bt tl] Settings
[1 ed mc fc] Search...
[2 bt sr tr] More options
```

Fields per element:
- `[n]` — element index (used in action codes)
- type: `bt`=Button, `ed`=EditText, `tx`=TextView, `im`=ImageView, `ck`=CheckBox, `sw`=Switch, `rd`=RadioButton, `sp`=Spinner, `sc`=ScrollView, `ls`=ListView/RecyclerView, `bar`=Toolbar, `tab`=TabLayout, `dw`=DrawerLayout, `vw`=other
- zone: screen position of element center — row (`t`=top, `m`=mid, `b`=bottom) + col (`l`=left, `c`=center, `r`=right), e.g. `tl`=top-left, `mc`=mid-center
- `fc` — element has keyboard focus (K action types here)
- `ed` — element is editable (text input)
- `sr` — element is scrollable
- `hd` — element supports long-press
- `ds` — element is disabled
- `ck`/`uc` — checkbox checked/unchecked
- `sl` — element is selected
- `pw` — password field
- `...above (N nodes, scroll up)` / `...below (N nodes, scroll down)` — off-screen content indicators
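
The following sketch shows how a consumer might parse these lines back into fields. It assumes each element line has the shape `[index type ...] label`, with the index omitted for non-clickable nodes. The exact rendering of those nodes, such as `[tx tc] Title`, is an assumption. The example above places the `sr` flag before the zone, so the parser identifies the zone by checking membership in the nine zone codes rather than by position. `parse_tree_line` is a hypothetical helper, not part of the project.

```python
from __future__ import annotations

import re

ZONES = {"tl", "tc", "tr", "ml", "mc", "mr", "bl", "bc", "br"}
# "[n type zone flags...] label"; the index is absent on non-clickable nodes.
LINE_RE = re.compile(r"^\[(?:(\d+)\s+)?([^\]]+)\]\s*(.*)$")


def parse_tree_line(line: str) -> dict | None:
    """Parse one element line; returns None for scroll indicators or other text."""
    m = LINE_RE.match(line.strip())
    if m is None:
        return None
    index, body, label = m.groups()
    tokens = body.split()
    node = {
        "index": int(index) if index is not None else None,
        "type": tokens[0],  # first token is the type tag (bt, ed, tx, ...)
        "zone": None,
        "flags": set(),
        "label": label,
    }
    for tok in tokens[1:]:
        if tok in ZONES:
            node["zone"] = tok
        else:
            node["flags"].add(tok.split("=", 1)[0])  # tolerate "fc=focused" style
    return node
```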

### Tree Compression Rules

The raw Android accessibility tree is compressed before being passed to the model:

1. **Node filtering** — nodes without text, content description, resource ID, clickability, or scrollability are collapsed (their children are promoted up)
2. **Off-screen filtering** — nodes fully outside the screen bounds (`x2<=0`, `y1>=screen_height`, etc.) are excluded; replaced with `...above (N)` / `...below (N)` scroll indicators
3. **Sibling deduplication** — identical sibling subtrees are rendered only once (handles repeated list items)
4. **Multi-window deduplication** — Android can return multiple root windows; duplicate root blocks are dropped. Roots whose tappable element IDs are fully covered by a larger root are also dropped (handles split-screen / overlay artifacts)
5. **Largest-first ordering** — when multiple roots exist, the most complete window (most lines) is rendered first
6. **Element indexing** — only `clickable=true AND enabled=true` nodes get a numeric index `[n]`. Non-clickable nodes are rendered without an index. Index order is top-to-bottom, left-to-right by element position
7. **Type abbreviation** — class names are mapped to short tags (e.g. `android.widget.Button` → `bt`)
8. **Zone encoding** — element center is bucketed into a 3×3 grid zone string (`tl/tc/tr/ml/mc/mr/bl/bc/br`)
9. **Label selection** prefers `text`, falling back to `content_desc` and then to `resource_id` (the last component after `/`)
10. **Hard truncation** — tree is truncated at 4000 characters before tokenization to prevent OOM on dense screens
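
Below is a minimal sketch of rules 2, 6, 8, 9 and 10. It assumes a hypothetical `Node` record with `(x1, y1, x2, y2)` pixel bounds and omits node filtering and the deduplication steps. The card does not say which point drives the top-to-bottom, left-to-right ordering, so sorting by the top-left corner is an assumption. Splitting the screen into equal thirds for the zone grid is also an assumption.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Node:
    # Hypothetical minimal node; bounds are (x1, y1, x2, y2) in screen pixels.
    bounds: tuple[int, int, int, int]
    text: str = ""
    content_desc: str = ""
    resource_id: str = ""
    clickable: bool = False
    enabled: bool = True


def is_offscreen(node: Node, width: int, height: int) -> bool:
    """Rule 2: the node lies entirely outside the screen rectangle."""
    x1, y1, x2, y2 = node.bounds
    return x2 <= 0 or y2 <= 0 or x1 >= width or y1 >= height


def zone(node: Node, width: int, height: int) -> str:
    """Rule 8: bucket the element center into a 3x3 grid (row letter + column letter)."""
    x1, y1, x2, y2 = node.bounds
    cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
    row = "tmb"[min(max(int(3 * cy / height), 0), 2)]
    col = "lcr"[min(max(int(3 * cx / width), 0), 2)]
    return row + col


def label(node: Node) -> str:
    """Rule 9: text, else content description, else last component of the resource ID."""
    return node.text or node.content_desc or node.resource_id.rsplit("/", 1)[-1]


def assign_indices(nodes: list[Node]) -> dict[int, Node]:
    """Rule 6: index only clickable and enabled nodes, top-to-bottom then left-to-right."""
    targets = [n for n in nodes if n.clickable and n.enabled]
    targets.sort(key=lambda n: (n.bounds[1], n.bounds[0]))  # assumption: top-left corner
    return dict(enumerate(targets))


def truncate(tree: str, limit: int = 4000) -> str:
    """Rule 10: hard character cap applied before tokenization."""
    return tree[:limit]
```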

---

## Output Format

A single action code per call, decoded greedily:

| Code | Action | Example |
|---|---|---|
| `T{n}` | Tap element n | `T7` |
| `P{n}` | Long-press element n | `P3` |
| `K {text}` | Type text into focused field | `K hello world` |
| `U` | Scroll up | `U` |
| `D` | Scroll down | `D` |
| `L` | Scroll left | `L` |
| `R` | Scroll right | `R` |
| `B` | System back | `B` |
| `H` | Home button | `H` |
| `W` | Wait (screen loading) | `W` |
| `F` | Done (task complete) | `F` |
| `I` | Impossible (task cannot complete) | `I` |

Single-token actions (U/D/L/R/B/H/W/F/I) self-terminate — no EOS token follows. T/P generate up to 5 tokens (letter + digits + EOS). K generates until EOS.
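
The code table maps onto a small parser. In this sketch, the `parse_action` helper and the structured action names are hypothetical. It returns `None` for anything outside the grammar, which is the signal an error-recovery step can act on.

```python
from __future__ import annotations

import re

SINGLE = {"U": "scroll_up", "D": "scroll_down", "L": "scroll_left", "R": "scroll_right",
          "B": "back", "H": "home", "W": "wait", "F": "done", "I": "impossible"}
INDEXED = {"T": "tap", "P": "long_press"}
INDEXED_RE = re.compile(r"([TP])(\d+)")


def parse_action(code: str) -> dict | None:
    """Map a decoded action code to a structured action; None means malformed."""
    code = code.strip()  # decoder output may carry stray whitespace
    if code in SINGLE:
        return {"action": SINGLE[code]}
    m = INDEXED_RE.fullmatch(code)
    if m:
        return {"action": INDEXED[m.group(1)], "index": int(m.group(2))}
    if code.startswith("K ") and len(code) > 2:
        return {"action": "type", "text": code[2:]}
    return None
```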

---

## Inference-Time Error Recovery

The model occasionally produces malformed outputs (action letter fused with wrong content, e.g. `B4`, `W3`, `T some text`). A lightweight validator detects these and retries with a disambiguating correction blurb appended to the prompt:

```
Next action:
'B4' is not valid. Did you mean 'B' (back) or 'T4' (tap element 4)? Try again:
```

This zero-shot correction resolves the majority of format errors without additional training.
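
Here is a sketch of the validate-and-retry loop. It assumes a hypothetical `generate(prompt) -> str` callable that wraps the model. Only the fused letter-plus-digits case (`B4`, `W3`) gets the disambiguating wording from the example. The generic fallback message, the names used for the scroll letters and the retry budget are assumptions. Each retry appends a single blurb to the original prompt rather than stacking blurbs.

```python
from __future__ import annotations

import re
from typing import Callable

NAMES = {"U": "scroll up", "D": "scroll down", "L": "scroll left", "R": "scroll right",
         "B": "back", "H": "home", "W": "wait", "F": "done", "I": "impossible"}
FUSED_RE = re.compile(r"([UDLRBHWFI])(\d+)")        # e.g. "B4", "W3"
VALID_RE = re.compile(r"[UDLRBHWFI]|[TP]\d+|K .+")  # the output grammar


def correction_blurb(bad: str) -> str:
    """Disambiguating retry hint for a malformed action code."""
    m = FUSED_RE.fullmatch(bad)
    if m:
        letter, n = m.groups()
        return (f"'{bad}' is not valid. Did you mean '{letter}' ({NAMES[letter]}) "
                f"or 'T{n}' (tap element {n})? Try again:")
    return f"'{bad}' is not valid. Try again:"  # hypothetical generic fallback


def predict_with_retry(generate: Callable[[str], str], prompt: str,
                       max_retries: int = 2) -> str | None:
    """Return the first well-formed action, or None once retries are exhausted."""
    current = prompt
    for _ in range(max_retries + 1):
        out = generate(current).strip()
        if VALID_RE.fullmatch(out):
            return out
        # Append the blurb on the line after the original "Next action:" cue.
        current = f"{prompt}\n{correction_blurb(out)}"
    return None
```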

---

## Performance (step 656)

Evaluated on the AndroidControl dataset (accessibility tree format, single-step predictions):

| Metric | Last 20 steps | Last 50 steps | All (102 steps) |
|---|---|---|---|
| Overall accuracy | 95.0% | 92.0% | 88.2% |
| Element index accuracy | 93.3% | 88.6% | 84.6% |
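
Assume the column headers count the most recent N evaluated predictions, each scored as a hit or a miss. The card does not define the windows further, so this is an assumption. Under it, the overall-accuracy row corresponds to 19/20, 46/50 and 90/102 correct predictions. A hypothetical trailing-window helper makes the computation explicit:

```python
from __future__ import annotations


def trailing_accuracy(hits: list[bool], window: int | None = None) -> float:
    """Share of correct predictions over the last `window` steps (all steps if None)."""
    if window is not None and window < 1:
        raise ValueError("window must be a positive integer")
    recent = hits if window is None else hits[-window:]
    if not recent:
        raise ValueError("no steps to score")
    return sum(recent) / len(recent)
```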

**Action type breakdown (last 20 steps):**

| Action | Accuracy |
|---|---|
| tap (T) | 93% |
| scroll (U/D/L/R) | 100% |
| back (B) | 100% |
| type (K) | 100% |
| wait (W) | 100% |

Remaining errors are primarily off-by-one element indices on tap targets, a known ceiling of SFT that RL fine-tuning (see Roadmap) is intended to address.

---

## Training Process

**Base model:** `google/gemma-4-E4B-it` (4B MoE, 4-bit quantized during training via bitsandbytes)

**LoRA config:**
- `r=64`, `alpha=32`, `dropout=0.05`
- Target modules: all linear layers in the transformer
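
One way to express this configuration with PEFT and bitsandbytes (via `transformers`) is sketched below. The card only states 4-bit quantization via bitsandbytes, so the `nf4` quantization type, the bf16 compute dtype and `task_type` are assumptions. `target_modules="all-linear"` is PEFT's shortcut for every linear layer except the output head. The card does not say whether the original run also adapted the vision encoder's linear layers.

```python
import torch
from peft import LoraConfig
from transformers import BitsAndBytesConfig

# 4-bit loading of the base model during training (quant type and compute dtype assumed).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# LoRA hyperparameters as listed on the card.
lora_config = LoraConfig(
    r=64,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)
```

These would be passed as `quantization_config=bnb_config` to `from_pretrained` and to `peft.get_peft_model(model, lora_config)`. With standard (non-rsLoRA) scaling, the adapter update is multiplied by `lora_alpha / r = 32 / 64 = 0.5`.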

**Training data:** AndroidControl dataset (accessibility tree variant), ~20 shards from GCS. Each sample is a single (screenshot, a11y tree, goal, history) → action step from a real Android interaction trajectory.

**Key training decisions:**
- No label smoothing — removed after identifying it softened action type gradients
- `accum_steps=1` — every sample is its own gradient update (maximum signal density)
- `lr=5e-5`, cosine schedule
- Grammar-constrained loss: inference-time cap per action type (T/P: 5 tokens max, single-token actions: 1 token). Wrong action type predictions lose access to downstream element-index reward
- Type token weights: tap=4.0, long_press=4.0, type=8.0, scrolls=8.0 (upweighted to prevent collapse)
- Sample weights: rare actions (back/home/wait/done/impossible) upweighted 3× to prevent tap dominance
- Rolling window diversity quota (window=20): ensures each action type appears proportionally in recent batches
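
The next sketch covers the sample weighting and the rolling-window quota. It assumes the quota works by choosing which action type to draw next. The card does not say how the target proportions are defined. The `targets` shares and all helper names are therefore hypothetical.

```python
from __future__ import annotations

from collections import Counter, deque

# Rare action types are upweighted 3x (type labels here are hypothetical).
RARE = {"back", "home", "wait", "done", "impossible"}


def sample_weight(action_type: str) -> float:
    return 3.0 if action_type in RARE else 1.0


def next_action_type(recent: deque[str], targets: dict[str, float]) -> str:
    """Pick the action type furthest below its target share in the recent window."""
    counts = Counter(recent)
    n = len(recent)

    def deficit(t: str) -> float:
        observed = counts[t] / n if n else 0.0
        return targets[t] - observed

    # Largest deficit wins; ties go to the larger target share, then alphabetical order.
    return max(sorted(targets), key=lambda t: (deficit(t), targets[t]))
```

In a training loop, the chosen type would be appended to a `deque(maxlen=20)` after each draw, so the window always reflects the last 20 samples.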

**Training infrastructure:**
- Single RTX 3060 12GB
- ~100s/step (full image + tree encoding + gradient update)
- Milestone checkpoints every ~100 steps via sentinel file

**To replicate from scratch:**
1. Download AndroidControl dataset (GCS, 20 shards, ~47GB)
2. Preprocess with `scripts/preprocess_a11y.py` to extract accessibility trees
3. Train: `py -3.11 -u scripts/train_autoregressive.py --out checkpoints/autoreg/current`
4. Resume: `py -3.11 -u scripts/train_autoregressive.py --resume checkpoints/autoreg/current/step_XXXXXXX --out checkpoints/autoreg/current`
5. Monitor: tail the log file for HIT/miss lines; ntfy.sh push notifications every 5 steps (topic: Cerebellum-Training)
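
For step 4, a small helper can locate the newest checkpoint. The helper is hypothetical and not part of `scripts/`. It assumes directories follow the `step_<digits>` pattern shown in the resume command. Steps are compared numerically, so zero-padding does not matter.

```python
from __future__ import annotations

import re
from pathlib import Path

STEP_RE = re.compile(r"step_(\d+)")


def latest_checkpoint(out_dir: str | Path) -> Path | None:
    """Return the step_<digits> subdirectory with the highest step, or None."""
    root = Path(out_dir)
    if not root.is_dir():
        return None
    steps = [
        (int(m.group(1)), p)
        for p in root.iterdir()
        if p.is_dir() and (m := STEP_RE.fullmatch(p.name))
    ]
    return max(steps)[1] if steps else None
```

Its result, e.g. `latest_checkpoint("checkpoints/autoreg/current")`, would then be passed as the `--resume` argument.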

---

## Loading the Adapter

```python
import torch
from peft import PeftModel
from transformers import AutoProcessor, Gemma4ForConditionalGeneration

# Load the base model in bf16 (training used a 4-bit quantized base).
base = Gemma4ForConditionalGeneration.from_pretrained(
    "google/gemma-4-E4B-it",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# Attach the Cerebellum LoRA adapter on top of the base weights.
model = PeftModel.from_pretrained(base, "dmitchelljackson/cerebellum-e4b-lora")
# Tokenizer and image preprocessing are loaded from the adapter repo.
processor = AutoProcessor.from_pretrained("dmitchelljackson/cerebellum-e4b-lora")
model.eval()  # disable dropout for inference
```

---

## Roadmap

- [x] SFT on AndroidControl (~88-95% single-step accuracy)
- [x] Inference-time error recovery (format validator + correction blurb)
- [ ] RL fine-tuning (GRPO) on AndroidWorld tasks for multi-step accuracy and semantic recovery
- [ ] Error recovery fine-tuning on collected failure cases