---
language: en
license: apache-2.0
base_model: google/gemma-4-E4B-it
tags:
- android
- ui-automation
- accessibility
- lora
- peft
---
# Cerebellum — Android UI Action Predictor
LoRA adapter on top of `google/gemma-4-E4B-it` that predicts the next Android UI action given a screenshot and accessibility tree.
**Architecture:** The LLM (or orchestrating agent) issues high-level intent. Cerebellum executes it locally by grounding intent to a specific UI element and action — without screenshot round-trips to a remote model.
---
## What It Does
Given a task goal, the current screen (screenshot + accessibility tree), and optional action history, the model outputs a single compact action code indicating what to do next.
---
## Input Format
The model uses a chat-style prompt (Gemma4 format). The user turn is structured as:
```
Task: {goal}
Step 1 (past): <|image|> -> {action_text}
Step 2 (past): <|image|> -> {action_text}
...
Current screen: <|image|>
{compressed_accessibility_tree}
[n zone]=tap-target(top-to-bottom left-to-right) zone=tl/tc/tr/ml/mc/mr/bl/bc/br ed=text-input sr=scrollable fc=focused(use 'K your_text' to type here)
Actions: T{n}=tap element n, P{n}=long-press element n, K {text}=type text(space required), U/D/L/R=scroll(single token), B=back, H=home, W=wait, F=done, I=impossible
Next action:
```
**Inputs:**
- `goal` — natural language task description (e.g. "Open the settings app and enable dark mode")
- `history` — up to 4 past (screenshot, action) pairs; can be empty
- `current screenshot` — PIL image of the current screen, resized to 896px on the long edge
- `compressed_accessibility_tree` — compact text representation of the UI element tree (see below)
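
The template above can be assembled with a small helper. This is a minimal sketch, not the author's code: `build_user_turn` and `resized_dims` are hypothetical names, keeping only the most recent four history steps is an assumption about how the history cap is applied, and the images themselves are assumed to be passed to the processor separately in the same order as the `<|image|>` placeholders.

```python
from typing import Sequence, Tuple

MAX_HISTORY = 4  # the card states up to 4 past (screenshot, action) pairs
IMAGE_TOKEN = "<|image|>"

# Legend and action list copied verbatim from the prompt template.
LEGEND = (
    "[n zone]=tap-target(top-to-bottom left-to-right) "
    "zone=tl/tc/tr/ml/mc/mr/bl/bc/br ed=text-input sr=scrollable "
    "fc=focused(use 'K your_text' to type here)"
)
ACTIONS = (
    "Actions: T{n}=tap element n, P{n}=long-press element n, "
    "K {text}=type text(space required), U/D/L/R=scroll(single token), "
    "B=back, H=home, W=wait, F=done, I=impossible"
)


def build_user_turn(goal: str, past_actions: Sequence[str], tree: str) -> str:
    """Assemble the user-turn text from goal, past action codes, and the tree."""
    # Assumption: when more than 4 steps exist, keep the most recent ones.
    recent = list(past_actions)[-MAX_HISTORY:]
    lines = [f"Task: {goal}"]
    for i, action in enumerate(recent, start=1):
        lines.append(f"Step {i} (past): {IMAGE_TOKEN} -> {action}")
    lines.append(f"Current screen: {IMAGE_TOKEN}")
    lines.extend([tree, LEGEND, ACTIONS, "Next action:"])
    return "\n".join(lines)


def resized_dims(width: int, height: int, long_edge: int = 896) -> Tuple[int, int]:
    """Target size whose longer side equals long_edge, preserving aspect ratio."""
    scale = long_edge / max(width, height)
    return round(width * scale), round(height * scale)
```

For example, `build_user_turn("Open settings", ["T3", "D"], "[0 bt tl] Settings")` yields two `Step` lines followed by the current-screen block, and `resized_dims` gives the size to pass to an image resize call before handing the screenshot to the processor.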
### Accessibility Tree Format
Each interactive element is one line:
```
[0 bt tl] Settings
[1 ed mc fc] Search...
[2 bt tr sr] More options
```
Fields per element:
- `[n]` — element index (used in action codes)
- type: `bt`=Button, `ed`=EditText, `tx`=TextView, `im`=ImageView, `ck`=CheckBox, `sw`=Switch, `rd`=RadioButton, `sp`=Spinner, `sc`=ScrollView, `ls`=ListView/RecyclerView, `bar`=Toolbar, `tab`=TabLayout, `dw`=DrawerLayout, `vw`=other
- zone: screen position of element center — row (`t`=top, `m`=mid, `b`=bottom) + col (`l`=left, `c`=center, `r`=right), e.g. `tl`=top-left, `mc`=mid-center
- `fc` — element has keyboard focus (K action types here)
- `ed` — element is editable (text input)
- `sr` — element is scrollable
- `hd` — element supports long-press
- `ds` — element is disabled
- `ck`/`uc` — checkbox checked/unchecked
- `sl` — element is selected
- `pw` — password field
- `...above (N nodes, scroll up)` / `...below (N nodes, scroll down)` — off-screen content indicators
### Tree Compression Rules
The raw Android accessibility tree is compressed before being passed to the model:
1. **Node filtering** — nodes without text, content description, resource ID, clickability, or scrollability are collapsed (their children are promoted up)
2. **Off-screen filtering** — nodes fully outside the screen bounds (`x2<=0`, `y1>=screen_height`, etc.) are excluded; replaced with `...above (N)` / `...below (N)` scroll indicators
3. **Sibling deduplication** — identical sibling subtrees are rendered only once (handles repeated list items)
4. **Multi-window deduplication** — Android can return multiple root windows; duplicate root blocks are dropped. Roots whose tappable element IDs are fully covered by a larger root are also dropped (handles split-screen / overlay artifacts)
5. **Largest-first ordering** — when multiple roots exist, the most complete window (most lines) is rendered first
6. **Element indexing** — only `clickable=true AND enabled=true` nodes get a numeric index `[n]`. Non-clickable nodes are rendered without an index. Index order is top-to-bottom, left-to-right by element position
7. **Type abbreviation** — class names are mapped to short tags (e.g. `android.widget.Button` → `bt`)
8. **Zone encoding** — element center is bucketed into a 3×3 grid zone string (`tl/tc/tr/ml/mc/mr/bl/bc/br`)
9. **Label selection** — `text` is preferred; falls back to `content_desc`; falls back to `resource_id` (last component after `/`)
10. **Hard truncation** — tree is truncated at 4000 characters before tokenization to prevent OOM on dense screens
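
A few of these rules are simple enough to sketch directly. The helpers below are illustrative only: the function names are hypothetical, bounds are assumed to be `(x1, y1, x2, y2)` pixel tuples, the zone grid is assumed to split the screen into equal thirds, and index ordering is assumed to sort on each element's top-left corner. The actual preprocessing script may differ on any of these points.

```python
from typing import List, Optional, Tuple

Bounds = Tuple[int, int, int, int]  # (x1, y1, x2, y2) in screen pixels (assumed layout)


def is_offscreen(b: Bounds, screen_w: int, screen_h: int) -> bool:
    """Rule 2: True when the node lies fully outside the screen."""
    x1, y1, x2, y2 = b
    return x2 <= 0 or y2 <= 0 or x1 >= screen_w or y1 >= screen_h


def zone(b: Bounds, screen_w: int, screen_h: int) -> str:
    """Rule 8: bucket the element center into a 3x3 grid (equal thirds assumed)."""
    x1, y1, x2, y2 = b
    cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
    col = min(2, max(0, int(3 * cx / screen_w)))
    row = min(2, max(0, int(3 * cy / screen_h)))
    return "tmb"[row] + "lcr"[col]


def label(text: Optional[str], content_desc: Optional[str],
          resource_id: Optional[str]) -> str:
    """Rule 9: prefer text, then content_desc, then the tail of resource_id."""
    if text:
        return text
    if content_desc:
        return content_desc
    if resource_id:
        return resource_id.rsplit("/", 1)[-1]
    return ""


def index_order(bounds: List[Bounds]) -> List[int]:
    """Rule 6: positions sorted top-to-bottom, then left-to-right."""
    return sorted(range(len(bounds)), key=lambda i: (bounds[i][1], bounds[i][0]))


def truncate_tree(tree: str, limit: int = 4000) -> str:
    """Rule 10: hard character cap applied before tokenization."""
    return tree[:limit]
```

On a 1080×2400 screen, an element centered at (540, 1200) falls in zone `mc`, and a resource ID such as `com.app:id/search_button` is labeled `search_button` when no text or content description is present.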
---
## Output Format
A single action code (one forward pass, greedy decode):
| Code | Action | Example |
|---|---|---|
| `T{n}` | Tap element n | `T7` |
| `P{n}` | Long-press element n | `P3` |
| `K {text}` | Type text into focused field | `K hello world` |
| `U` | Scroll up | `U` |
| `D` | Scroll down | `D` |
| `L` | Scroll left | `L` |
| `R` | Scroll right | `R` |
| `B` | System back | `B` |
| `H` | Home button | `H` |
| `W` | Wait (screen loading) | `W` |
| `F` | Done (task complete) | `F` |
| `I` | Impossible (task cannot complete) | `I` |
Single-token actions (U/D/L/R/B/H/W/F/I) self-terminate — no EOS token follows. T/P generate up to 5 tokens (letter + digits + EOS). K generates until EOS.
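
On the consuming side, the action code has to be turned into something an executor can dispatch on. The following is a minimal parsing sketch based only on the table above; the `Action` dataclass, the `kind` labels, and `parse_action` are hypothetical names, not part of the released code.

```python
import re
from dataclasses import dataclass
from typing import Optional

# Single-letter codes from the output table (kind labels are illustrative).
SINGLE = {
    "U": "scroll_up", "D": "scroll_down", "L": "scroll_left", "R": "scroll_right",
    "B": "back", "H": "home", "W": "wait", "F": "done", "I": "impossible",
}
INDEXED = {"T": "tap", "P": "long_press"}


@dataclass
class Action:
    kind: str
    index: Optional[int] = None  # element index for tap / long-press
    text: Optional[str] = None   # payload for type actions


def parse_action(code: str) -> Optional[Action]:
    """Return a structured Action, or None when the code is malformed."""
    code = code.strip()
    if code in SINGLE:
        return Action(SINGLE[code])
    m = re.fullmatch(r"([TP])(\d+)", code)
    if m:
        return Action(INDEXED[m.group(1)], index=int(m.group(2)))
    # 'K' requires a space followed by at least one character of text.
    if code.startswith("K ") and len(code) > 2:
        return Action("type", text=code[2:])
    return None
```

Returning `None` for anything outside the grammar (for example `B4` or `T some text`) gives the caller a single place to decide whether to retry or abort.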
---
## Inference-Time Error Recovery
The model occasionally produces malformed outputs (action letter fused with wrong content, e.g. `B4`, `W3`, `T some text`). A lightweight validator detects these and retries with a disambiguating correction blurb appended to the prompt:
```
Next action:
'B4' is not valid. Did you mean 'B' (back) or 'T4' (tap element 4)? Try again:
```
This zero-shot correction resolves the majority of format errors without additional training.
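
The validate-and-retry loop can be sketched as follows. This is an assumption-laden illustration rather than the shipped validator: `is_valid`, `correction_blurb`, and `predict_with_retry` are hypothetical names, `generate` stands in for whatever callable runs the model on a prompt string, the retry count is arbitrary, and only the fused-digit wording (`B4`) comes from the example above. The generic fallback message is invented for the sketch.

```python
import re
from typing import Callable, Optional

SINGLE = {
    "U": "scroll up", "D": "scroll down", "L": "scroll left", "R": "scroll right",
    "B": "back", "H": "home", "W": "wait", "F": "done", "I": "impossible",
}
VALID = re.compile(r"(?:[UDLRBHWFI]|[TP]\d+|K .+)", re.DOTALL)


def is_valid(code: str) -> bool:
    """True when the code matches the action grammar from the output table."""
    return VALID.fullmatch(code) is not None


def correction_blurb(bad: str) -> str:
    """Build the disambiguating message appended after 'Next action:'."""
    m = re.fullmatch(r"([UDLRBHWFI])(\d+)", bad)
    if m:
        letter, n = m.groups()
        return (f"'{bad}' is not valid. Did you mean '{letter}' ({SINGLE[letter]}) "
                f"or 'T{n}' (tap element {n})? Try again:")
    # Assumption: generic wording for other malformed outputs.
    return f"'{bad}' is not valid. Reply with exactly one action code. Try again:"


def predict_with_retry(generate: Callable[[str], str], prompt: str,
                       max_retries: int = 2) -> Optional[str]:
    """Call the model, re-prompting with a correction blurb on malformed output."""
    out = generate(prompt).strip()
    for _ in range(max_retries):
        if is_valid(out):
            return out
        out = generate(f"{prompt}\n{correction_blurb(out)}").strip()
    return out if is_valid(out) else None
```

Each retry appends the blurb to the original prompt rather than stacking corrections, which keeps the prompt length bounded; whether the real validator does the same is not stated in this card.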
---
## Performance (step 656)
Evaluated on the AndroidControl dataset (accessibility tree format, single-step predictions):
| Metric | Last 20 steps | Last 50 steps | All (102 steps) |
|---|---|---|---|
| Overall accuracy | 95.0% | 92.0% | 88.2% |
| Element index accuracy | 93.3% | 88.6% | 84.6% |
**Action type breakdown (last 20 steps):**
| Action | Accuracy |
|---|---|
| tap (T) | 93% |
| scroll (U/D/L/R) | 100% |
| back (B) | 100% |
| type (K) | 100% |
| wait (W) | 100% |
Remaining errors are primarily off-by-one element indices on tap targets, a known SFT ceiling that the planned RL fine-tuning (see Roadmap) is intended to address.
---
## Training Process
**Base model:** `google/gemma-4-E4B-it` (4B MoE, 4-bit quantized during training via bitsandbytes)
**LoRA config:**
- `r=64`, `alpha=32`, `dropout=0.05`
- Target modules: all linear layers in the transformer
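
For readers reproducing this setup with PEFT, the hyperparameters above map onto `peft.LoraConfig` keyword arguments roughly as shown here. The use of PEFT's `"all-linear"` shortcut is an assumption: the card says "all linear layers in the transformer" but does not give the exact module list.

```python
# Keyword arguments for peft.LoraConfig, taken from the values in this card.
lora_kwargs = {
    "r": 64,
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    # Assumption: PEFT's "all-linear" shortcut stands in for the author's
    # actual target-module list, which is not given here.
    "target_modules": "all-linear",
}

# Standard LoRA scales the low-rank update by alpha / r.
scaling = lora_kwargs["lora_alpha"] / lora_kwargs["r"]
print(scaling)  # 0.5

# With peft installed, the config would be built as:
#   from peft import LoraConfig
#   config = LoraConfig(**lora_kwargs)
```

With `alpha` at half of `r`, the adapter update is scaled by 0.5, a more conservative setting than the common `alpha = 2r` convention.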
**Training data:** AndroidControl dataset (accessibility tree variant), ~20 shards from GCS. Each sample is a single (screenshot, a11y tree, goal, history) → action step from a real Android interaction trajectory.
**Key training decisions:**
- No label smoothing — removed after identifying it softened action type gradients
- `accum_steps=1` — every sample is its own gradient update (maximum signal density)
- `lr=5e-5`, cosine schedule
- Grammar-constrained loss: inference-time cap per action type (T/P: 5 tokens max, single-token actions: 1 token). Wrong action type predictions lose access to downstream element-index reward
- Type token weights: tap=4.0, long_press=4.0, type=8.0, scrolls=8.0 (upweighted to prevent collapse)
- Sample weights: rare actions (back/home/wait/done/impossible) upweighted 3× to prevent tap dominance
- Rolling window diversity quota (window=20): ensures each action type appears proportionally in recent batches
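
The weighting and quota decisions above can be sketched in a few lines. This is a guess at the mechanics, not the training script: all names are hypothetical, action types not listed in the card are assumed to have a token weight of 1.0, and the `allows` admission rule is one possible reading of "appears proportionally in recent batches".

```python
from collections import Counter, deque

# Values from the card; keys are illustrative labels.
TYPE_TOKEN_WEIGHT = {"tap": 4.0, "long_press": 4.0, "type": 8.0, "scroll": 8.0}
RARE_ACTIONS = {"back", "home", "wait", "done", "impossible"}
RARE_SAMPLE_WEIGHT = 3.0


def sample_weight(action_type: str) -> float:
    """Per-sample loss weight: rare actions are upweighted 3x."""
    return RARE_SAMPLE_WEIGHT if action_type in RARE_ACTIONS else 1.0


def type_token_weight(action_type: str) -> float:
    """Weight on the action-type token (unlisted types default to 1.0, assumed)."""
    return TYPE_TOKEN_WEIGHT.get(action_type, 1.0)


class RollingQuota:
    """Tracks the action types of the most recent `window` training samples."""

    def __init__(self, window: int = 20) -> None:
        self.recent = deque(maxlen=window)

    def record(self, action_type: str) -> None:
        self.recent.append(action_type)

    def share(self, action_type: str) -> float:
        """Fraction of the current window occupied by this action type."""
        if not self.recent:
            return 0.0
        return Counter(self.recent)[action_type] / len(self.recent)

    def allows(self, action_type: str, max_share: float) -> bool:
        """Hypothetical rule: admit a sample unless its type exceeds max_share."""
        return self.share(action_type) <= max_share
```

A data loader could call `allows` before admitting a sample and `record` afterwards, skipping tap samples while taps dominate the window.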
**Training infrastructure:**
- Single RTX 3060 12GB
- ~100s/step (full image + tree encoding + gradient update)
- Milestone checkpoints every ~100 steps via sentinel file
**To replicate from scratch:**
1. Download AndroidControl dataset (GCS, 20 shards, ~47GB)
2. Preprocess with `scripts/preprocess_a11y.py` to extract accessibility trees
3. Train: `py -3.11 -u scripts/train_autoregressive.py --out checkpoints/autoreg/current`
4. Resume: `py -3.11 -u scripts/train_autoregressive.py --resume checkpoints/autoreg/current/step_XXXXXXX --out checkpoints/autoreg/current`
5. Monitor: tail the log file for HIT/miss lines; ntfy.sh push notifications every 5 steps (topic: Cerebellum-Training)
---
## Loading the Adapter
```python
import torch
from peft import PeftModel
from transformers import AutoProcessor, Gemma4ForConditionalGeneration

# Load the base model in bfloat16 and place it on the available device(s).
base = Gemma4ForConditionalGeneration.from_pretrained(
    "google/gemma-4-E4B-it",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# Attach the LoRA adapter weights on top of the base model.
model = PeftModel.from_pretrained(base, "dmitchelljackson/cerebellum-e4b-lora")
processor = AutoProcessor.from_pretrained("dmitchelljackson/cerebellum-e4b-lora")
model.eval()  # inference mode: disables dropout
```
---
## Roadmap
- [x] SFT on AndroidControl (~88-95% single-step accuracy)
- [x] Inference-time error recovery (format validator + correction blurb)
- [ ] RL fine-tuning (GRPO) on AndroidWorld tasks for multi-step accuracy and semantic recovery
- [ ] Error recovery fine-tuning on collected failure cases