---
language: en
license: apache-2.0
base_model: google/gemma-4-E4B-it
tags:
- android
- ui-automation
- accessibility
- lora
- peft
---
# Cerebellum — Android UI Action Predictor
LoRA adapter on top of `google/gemma-4-E4B-it` that predicts the next Android UI action given a screenshot and accessibility tree.
**Architecture:** The LLM (or orchestrating agent) issues high-level intent. Cerebellum executes it locally by grounding intent to a specific UI element and action — without screenshot round-trips to a remote model.
---
## What It Does
Given a task goal, the current screen (screenshot + accessibility tree), and optional action history, the model outputs a single compact action code indicating what to do next.
---
## Input Format
The model uses a chat-style prompt (Gemma4 format). The user turn is structured as:
```
Task: {goal}
Step 1 (past): <|image|> -> {action_text}
Step 2 (past): <|image|> -> {action_text}
...
Current screen: <|image|>
{compressed_accessibility_tree}
[n zone]=tap-target(top-to-bottom left-to-right) zone=tl/tc/tr/ml/mc/mr/bl/bc/br ed=text-input sr=scrollable fc=focused(use 'K your_text' to type here)
Actions: T{n}=tap element n, P{n}=long-press element n, K {text}=type text(space required), U/D/L/R=scroll(single token), B=back, H=home, W=wait, F=done, I=impossible
Next action:
```
**Inputs:**
- `goal` — natural language task description (e.g. "Open the settings app and enable dark mode")
- `history` — up to 4 past (screenshot, action) pairs; can be empty
- `current screenshot` — PIL image of the current screen, resized to 896px on the long edge
- `compressed_accessibility_tree` — compact text representation of the UI element tree (see below)
### Accessibility Tree Format
Each interactive element is one line:
```
[0 bt tl] Settings
[1 ed mc fc] Search...
[2 bt sr tr] More options
```
Fields per element:
- `[n]` — element index (used in action codes)
- type: `bt`=Button, `ed`=EditText, `tx`=TextView, `im`=ImageView, `ck`=CheckBox, `sw`=Switch, `rd`=RadioButton, `sp`=Spinner, `sc`=ScrollView, `ls`=ListView/RecyclerView, `bar`=Toolbar, `tab`=TabLayout, `dw`=DrawerLayout, `vw`=other
- zone: screen position of element center — row (`t`=top, `m`=mid, `b`=bottom) + col (`l`=left, `c`=center, `r`=right), e.g. `tl`=top-left, `mc`=mid-center
- `fc` — element has keyboard focus (K action types here)
- `ed` — element is editable (text input)
- `sr` — element is scrollable
- `hd` — element supports long-press
- `ds` — element is disabled
- `ck`/`uc` — checkbox checked/unchecked
- `sl` — element is selected
- `pw` — password field
- `...above (N nodes, scroll up)` / `...below (N nodes, scroll down)` — off-screen content indicators
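For downstream tooling it can help to turn an indexed tree line back into structured fields. The sketch below is illustrative only: `Element` and `parse_element` are hypothetical names, the first token after the index is assumed to be the type code, and any token that is not a zone is treated as a flag. Lines without a numeric index (non-clickable nodes and the scroll indicators) return `None`, since their exact layout is not specified here.

```python
import re
from typing import NamedTuple, Optional, Tuple

ZONES = {"tl", "tc", "tr", "ml", "mc", "mr", "bl", "bc", "br"}
LINE_RE = re.compile(r"^\[(\d+)\s+([^\]]+)\]\s*(.*)$")


class Element(NamedTuple):
    index: int
    type_code: str
    zone: Optional[str]
    flags: Tuple[str, ...]
    label: str


def parse_element(line: str) -> Optional[Element]:
    """Parse one indexed element line; return None for anything else."""
    m = LINE_RE.match(line.strip())
    if m is None:
        return None
    tokens = m.group(2).split()
    if not tokens:
        return None
    type_code, rest = tokens[0], tokens[1:]  # assumption: type comes first
    zone = next((t for t in rest if t in ZONES), None)
    flags = tuple(t for t in rest if t not in ZONES)
    return Element(int(m.group(1)), type_code, zone, flags, m.group(3))
```

Because zone and flag tokens are recognized by value rather than position, the parser tolerates either ordering of the two.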
### Tree Compression Rules
The raw Android accessibility tree is compressed before being passed to the model:
1. **Node filtering** — nodes without text, content description, resource ID, clickability, or scrollability are collapsed (their children are promoted up)
2. **Off-screen filtering** — nodes fully outside the screen bounds (`x2<=0`, `y1>=screen_height`, etc.) are excluded; replaced with `...above (N)` / `...below (N)` scroll indicators
3. **Sibling deduplication** — identical sibling subtrees are rendered only once (handles repeated list items)
4. **Multi-window deduplication** — Android can return multiple root windows; duplicate root blocks are dropped. Roots whose tappable element IDs are fully covered by a larger root are also dropped (handles split-screen / overlay artifacts)
5. **Largest-first ordering** — when multiple roots exist, the most complete window (most lines) is rendered first
6. **Element indexing** — only `clickable=true AND enabled=true` nodes get a numeric index `[n]`. Non-clickable nodes are rendered without an index. Index order is top-to-bottom, left-to-right by element position
7. **Type abbreviation** — class names are mapped to short tags (e.g. `android.widget.Button` → `bt`)
8. **Zone encoding** — element center is bucketed into a 3×3 grid zone string (`tl/tc/tr/ml/mc/mr/bl/bc/br`)
9. **Label selection** — `text` is preferred; falls back to `content_desc`; falls back to `resource_id` (last component after `/`)
10. **Hard truncation** — tree is truncated at 4000 characters before tokenization to prevent OOM on dense screens
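A few of the rules above (off-screen filtering, zone encoding, label selection, hard truncation) are simple enough to sketch directly. The helper names are hypothetical, and the grid is assumed to split the screen into equal thirds; the actual preprocessing script may bucket differently.

```python
from typing import Optional


def is_offscreen(x1: int, y1: int, x2: int, y2: int,
                 screen_w: int, screen_h: int) -> bool:
    """Rule 2: True when the bounds lie fully outside the screen."""
    return x2 <= 0 or y2 <= 0 or x1 >= screen_w or y1 >= screen_h


def zone_for(cx: float, cy: float, screen_w: int, screen_h: int) -> str:
    """Rule 8: bucket an element center into a 3x3 grid (equal thirds assumed)."""
    col = min(2, max(0, int(3 * cx / screen_w)))
    row = min(2, max(0, int(3 * cy / screen_h)))
    return "tmb"[row] + "lcr"[col]


def pick_label(text: Optional[str], content_desc: Optional[str],
               resource_id: Optional[str]) -> str:
    """Rule 9: prefer text, then content_desc, then the resource_id tail."""
    if text:
        return text
    if content_desc:
        return content_desc
    if resource_id:
        return resource_id.rsplit("/", 1)[-1]
    return ""


def truncate_tree(tree: str, limit: int = 4000) -> str:
    """Rule 10: hard character cap applied before tokenization."""
    return tree[:limit]
```

For example, an element centered at (540, 1200) on a 1080×2400 screen lands in `mc` under the equal-thirds assumption.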
---
## Output Format
A single action code, decoded greedily. Single-token actions need one forward pass; `T`, `P`, and `K` codes take a few more decoding steps:
| Code | Action | Example |
|---|---|---|
| `T{n}` | Tap element n | `T7` |
| `P{n}` | Long-press element n | `P3` |
| `K {text}` | Type text into focused field | `K hello world` |
| `U` | Scroll up | `U` |
| `D` | Scroll down | `D` |
| `L` | Scroll left | `L` |
| `R` | Scroll right | `R` |
| `B` | System back | `B` |
| `H` | Home button | `H` |
| `W` | Wait (screen loading) | `W` |
| `F` | Done (task complete) | `F` |
| `I` | Impossible (task cannot complete) | `I` |
Single-token actions (U/D/L/R/B/H/W/F/I) self-terminate — no EOS token follows. T/P generate up to 5 tokens (letter + digits + EOS). K generates until EOS.
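The action table maps naturally onto a small parser. The sketch below is one possible way to turn a decoded string into a structured action; `Action`, `parse_action`, and the `kind` names are hypothetical and not part of the released code.

```python
import re
from typing import NamedTuple, Optional

SINGLE = {
    "U": "scroll_up", "D": "scroll_down", "L": "scroll_left", "R": "scroll_right",
    "B": "back", "H": "home", "W": "wait", "F": "done", "I": "impossible",
}
INDEXED = {"T": "tap", "P": "long_press"}


class Action(NamedTuple):
    kind: str
    index: Optional[int] = None
    text: Optional[str] = None


def parse_action(code: str) -> Optional[Action]:
    """Return a structured action, or None if the code is malformed."""
    code = code.strip()
    if code in SINGLE:
        return Action(SINGLE[code])
    m = re.fullmatch(r"([TP])(\d+)", code)
    if m:
        return Action(INDEXED[m.group(1)], index=int(m.group(2)))
    # K requires a space followed by at least one character of text
    if code.startswith("K ") and len(code) > 2:
        return Action("type", text=code[2:])
    return None
```

A `None` result marks the output as malformed, which a caller can use to decide whether to retry.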
---
## Inference-Time Error Recovery
The model occasionally produces malformed outputs (action letter fused with wrong content, e.g. `B4`, `W3`, `T some text`). A lightweight validator detects these and retries with a disambiguating correction blurb appended to the prompt:
```
Next action:
'B4' is not valid. Did you mean 'B' (back) or 'T4' (tap element 4)? Try again:
```
This zero-shot correction resolves the majority of format errors without additional training.
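A validator-and-retry loop along these lines might look like the sketch below. It is not the project's implementation: `generate` stands in for whatever callable runs the model on a prompt string, the retry count is an assumption, and the generic fallback message is invented for illustration. Only the `'B4'`-style blurb follows the example above.

```python
import re
from typing import Callable, Optional

SINGLE_NAMES = {
    "U": "scroll up", "D": "scroll down", "L": "scroll left", "R": "scroll right",
    "B": "back", "H": "home", "W": "wait", "F": "done", "I": "impossible",
}
VALID_RE = re.compile(r"(?:[UDLRBHWFI]|[TP]\d+|K .+)", re.DOTALL)


def is_valid(code: str) -> bool:
    return VALID_RE.fullmatch(code) is not None


def correction_blurb(bad: str) -> str:
    """Build a disambiguating hint for a malformed output."""
    m = re.fullmatch(r"([UDLRBHWFI])(\d+)", bad)
    if m:  # single-token letter fused with an index, e.g. 'B4'
        letter, n = m.groups()
        return (f"'{bad}' is not valid. Did you mean '{letter}' "
                f"({SINGLE_NAMES[letter]}) or 'T{n}' (tap element {n})? Try again:")
    # Hypothetical generic fallback; the wording is an assumption.
    return f"'{bad}' is not valid. Output exactly one action code. Try again:"


def predict_with_retry(generate: Callable[[str], str], prompt: str,
                       max_retries: int = 2) -> Optional[str]:
    """Call generate, re-prompting with a correction blurb on malformed output."""
    current = prompt
    for _ in range(max_retries + 1):
        out = generate(current).strip()
        if is_valid(out):
            return out
        current = f"{prompt}\n{correction_blurb(out)}"
    return None
```

Each retry rebuilds the prompt from the original rather than stacking blurbs, so the context does not grow with repeated failures.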
---
## Performance (step 656)
Evaluated on AndroidControl dataset (accessibility tree format, single-step predictions):
| Metric | Last 20 steps | Last 50 steps | All (102 steps) |
|---|---|---|---|
| Overall accuracy | 95.0% | 92.0% | 88.2% |
| Element index accuracy | 93.3% | 88.6% | 84.6% |
**Action type breakdown (last 20 steps):**
| Action | Accuracy |
|---|---|
| tap (T) | 93% |
| scroll (U/D/L/R) | 100% |
| back (B) | 100% |
| type (K) | 100% |
| wait (W) | 100% |
Remaining errors are primarily off-by-one element indices on tap targets, a known SFT ceiling that the planned RL fine-tuning (see Roadmap) is meant to address.
---
## Training Process
**Base model:** `google/gemma-4-E4B-it` (4B MoE, 4-bit quantized during training via bitsandbytes)
**LoRA config:**
- `r=64`, `alpha=32`, `dropout=0.05`
- Target modules: all linear layers in the transformer
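Expressed with PEFT, the settings above would look roughly like this. Treat it as a sketch: `target_modules="all-linear"` is PEFT's shorthand for every linear layer and may select more modules than the training script actually did (for example, vision-tower layers), and `task_type` is an assumption.

```python
from peft import LoraConfig

lora_config = LoraConfig(
    r=64,
    lora_alpha=32,                # effective scaling is lora_alpha / r = 0.5
    lora_dropout=0.05,
    target_modules="all-linear",  # assumption: PEFT shorthand for all linear layers
    task_type="CAUSAL_LM",        # assumption: not stated in this card
)
```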
**Training data:** AndroidControl dataset (accessibility tree variant), ~20 shards from GCS. Each sample is a single (screenshot, a11y tree, goal, history) → action step from a real Android interaction trajectory.
**Key training decisions:**
- No label smoothing — removed after identifying it softened action type gradients
- `accum_steps=1` — every sample is its own gradient update (maximum signal density)
- `lr=5e-5`, cosine schedule
- Grammar-constrained loss: inference-time cap per action type (T/P: 5 tokens max, single-token actions: 1 token). Wrong action type predictions lose access to downstream element-index reward
- Type token weights: tap=4.0, long_press=4.0, type=8.0, scrolls=8.0 (upweighted to prevent collapse)
- Sample weights: rare actions (back/home/wait/done/impossible) upweighted 3× to prevent tap dominance
- Rolling window diversity quota (window=20): ensures each action type appears proportionally in recent batches
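The weighting and diversity settings above can be illustrated with a small sketch. How the training script combines these weights, and how it enforces the quota, is not specified here, so the code only looks up the two weights and tracks per-type shares over a 20-sample window. All names are hypothetical, and a default type-token weight of 1.0 for unlisted actions is an assumption.

```python
from collections import Counter, deque
from typing import Tuple

# Type-token weights as listed above; unlisted action types default to 1.0 (assumption).
TYPE_TOKEN_WEIGHT = {
    "tap": 4.0, "long_press": 4.0, "type": 8.0,
    "scroll_up": 8.0, "scroll_down": 8.0, "scroll_left": 8.0, "scroll_right": 8.0,
}
RARE_ACTIONS = {"back", "home", "wait", "done", "impossible"}
RARE_SAMPLE_WEIGHT = 3.0


def loss_weights(action_type: str) -> Tuple[float, float]:
    """Return (type_token_weight, sample_weight) for one training sample."""
    token_w = TYPE_TOKEN_WEIGHT.get(action_type, 1.0)
    sample_w = RARE_SAMPLE_WEIGHT if action_type in RARE_ACTIONS else 1.0
    return token_w, sample_w


class RollingWindow:
    """Track the share of each action type over the most recent samples."""

    def __init__(self, window: int = 20) -> None:
        self.recent = deque(maxlen=window)

    def record(self, action_type: str) -> None:
        self.recent.append(action_type)

    def share(self, action_type: str) -> float:
        if not self.recent:
            return 0.0
        return Counter(self.recent)[action_type] / len(self.recent)
```

A sampler could consult `share` to favor action types that are under-represented in the current window.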
**Training infrastructure:**
- Single RTX 3060 12GB
- ~100s/step (full image + tree encoding + gradient update)
- Milestone checkpoints every ~100 steps via sentinel file
**To replicate from scratch:**
1. Download AndroidControl dataset (GCS, 20 shards, ~47GB)
2. Preprocess with `scripts/preprocess_a11y.py` to extract accessibility trees
3. Train: `py -3.11 -u scripts/train_autoregressive.py --out checkpoints/autoreg/current`
4. Resume: `py -3.11 -u scripts/train_autoregressive.py --resume checkpoints/autoreg/current/step_XXXXXXX --out checkpoints/autoreg/current`
5. Monitor: tail the log file for HIT/miss lines; ntfy.sh push notifications every 5 steps (topic: Cerebellum-Training)
---
## Loading the Adapter
```python
import torch
from peft import PeftModel
from transformers import AutoProcessor, Gemma4ForConditionalGeneration

# Load the base model in bfloat16 and spread it across available devices
base = Gemma4ForConditionalGeneration.from_pretrained(
    "google/gemma-4-E4B-it",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# Attach the LoRA adapter weights on top of the base model
model = PeftModel.from_pretrained(base, "dmitchelljackson/cerebellum-e4b-lora")
processor = AutoProcessor.from_pretrained("dmitchelljackson/cerebellum-e4b-lora")
model.eval()  # inference mode (disables dropout)
```
---
## Roadmap
- [x] SFT on AndroidControl (~88-95% single-step accuracy)
- [x] Inference-time error recovery (format validator + correction blurb)
- [ ] RL fine-tuning (GRPO) on AndroidWorld tasks for multi-step accuracy and semantic recovery
- [ ] Error recovery fine-tuning on collected failure cases