---
language: en
license: apache-2.0
base_model: google/gemma-4-E4B-it
tags:
- android
- ui-automation
- accessibility
- lora
- peft
---

# Cerebellum: Android UI Action Predictor

LoRA adapter on top of `google/gemma-4-E4B-it` that predicts the next Android UI action given a screenshot and accessibility tree.

**Architecture:** The LLM (or orchestrating agent) issues high-level intent. Cerebellum executes it locally by grounding intent to a specific UI element and action, without screenshot round-trips to a remote model.

---

## What It Does

Given a task goal, the current screen (screenshot + accessibility tree), and optional action history, the model outputs a single compact action code indicating what to do next.

---

## Input Format

The model uses a chat-style prompt (Gemma4 format). The user turn is structured as:

```
Task: {goal}
Step 1 (past): <|image|> -> {action_text}
Step 2 (past): <|image|> -> {action_text}
...
Current screen: <|image|>
{compressed_accessibility_tree}
[n zone]=tap-target(top-to-bottom left-to-right) zone=tl/tc/tr/ml/mc/mr/bl/bc/br ed=text-input sr=scrollable fc=focused(use 'K your_text' to type here)
Actions: T{n}=tap element n, P{n}=long-press element n, K {text}=type text(space required), U/D/L/R=scroll(single token), B=back, H=home, W=wait, F=done, I=impossible
Next action:
```

**Inputs:**

- `goal`: natural language task description (e.g. "Open the settings app and enable dark mode")
- `history`: up to 4 past (screenshot, action) pairs; can be empty
- `current screenshot`: PIL image of the current screen, resized to 896px on the long edge
- `compressed_accessibility_tree`: compact text representation of the UI element tree (see below)

### Accessibility Tree Format

Each interactive element is one line:

```
[0 btn tl] Settings
[1 ed mc fc=focused] Search...
[2 btn sr tr] More options
```

Fields per element:

- `[n]`: element index (used in action codes)
- type: `bt`=Button, `ed`=EditText, `tx`=TextView, `im`=ImageView, `ck`=CheckBox, `sw`=Switch, `rd`=RadioButton, `sp`=Spinner, `sc`=ScrollView, `ls`=ListView/RecyclerView, `bar`=Toolbar, `tab`=TabLayout, `dw`=DrawerLayout, `vw`=other
- zone: screen position of the element center, given as row (`t`=top, `m`=mid, `b`=bottom) + column (`l`=left, `c`=center, `r`=right), e.g. `tl`=top-left, `mc`=mid-center
- `fc`: element has keyboard focus (the K action types here)
- `ed`: element is editable (text input)
- `sr`: element is scrollable
- `hd`: element supports long-press
- `ds`: element is disabled
- `ck`/`uc`: checkbox checked/unchecked
- `sl`: element is selected
- `pw`: password field
- `...above (N nodes, scroll up)` / `...below (N nodes, scroll down)`: off-screen content indicators

### Tree Compression Rules

The raw Android accessibility tree is compressed before being passed to the model:

1. **Node filtering:** nodes without text, content description, resource ID, clickability, or scrollability are collapsed (their children are promoted up)
2. **Off-screen filtering:** nodes fully outside the screen bounds (`x2<=0`, `y1>=screen_height`, etc.) are excluded and replaced with `...above (N)` / `...below (N)` scroll indicators
3. **Sibling deduplication:** identical sibling subtrees are rendered only once (handles repeated list items)
4. **Multi-window deduplication:** Android can return multiple root windows; duplicate root blocks are dropped. Roots whose tappable element IDs are fully covered by a larger root are also dropped (handles split-screen / overlay artifacts)
5. **Largest-first ordering:** when multiple roots exist, the most complete window (most lines) is rendered first
6. **Element indexing:** only `clickable=true AND enabled=true` nodes get a numeric index `[n]`. Non-clickable nodes are rendered without an index.
Index order is top-to-bottom, left-to-right by element position.
7. **Type abbreviation:** class names are mapped to short tags (e.g. `android.widget.Button` → `bt`)
8. **Zone encoding:** the element center is bucketed into a 3×3 grid zone string (`tl/tc/tr/ml/mc/mr/bl/bc/br`)
9. **Label selection:** `text` is preferred; falls back to `content_desc`, then to `resource_id` (last component after `/`)
10. **Hard truncation:** the tree is truncated at 4000 characters before tokenization to prevent OOM on dense screens

---

## Output Format

A single action code (one forward pass, greedy decode):

| Code | Action | Example |
|---|---|---|
| `T{n}` | Tap element n | `T7` |
| `P{n}` | Long-press element n | `P3` |
| `K {text}` | Type text into focused field | `K hello world` |
| `U` | Scroll up | `U` |
| `D` | Scroll down | `D` |
| `L` | Scroll left | `L` |
| `R` | Scroll right | `R` |
| `B` | System back | `B` |
| `H` | Home button | `H` |
| `W` | Wait (screen loading) | `W` |
| `F` | Done (task complete) | `F` |
| `I` | Impossible (task cannot complete) | `I` |

Single-token actions (U/D/L/R/B/H/W/F/I) self-terminate: no EOS token follows. T/P generate up to 5 tokens (letter + digits + EOS). K generates until EOS.

---

## Inference-Time Error Recovery

The model occasionally produces malformed outputs in which the action letter is fused with the wrong content (e.g. `B4`, `W3`, `T some text`). A lightweight validator detects these and retries with a disambiguating correction blurb appended to the prompt:

```
Next action: 'B4' is not valid. Did you mean 'B' (back) or 'T4' (tap element 4)? Try again:
```

This zero-shot correction resolves the majority of format errors without additional training.
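
The validator itself is not included in this card, so the following is only a minimal sketch of how such a check could look, assuming the action grammar from the Output Format table. The names `SINGLE`, `INDEXED`, `is_valid`, and `correction_blurb` are hypothetical, and the `K` suggestion for outputs like `T some text` is an assumption rather than the documented behavior.

```python
import re
from typing import Optional

# Argument-free actions from the Output Format table (letter -> description).
SINGLE = {
    "U": "scroll up", "D": "scroll down", "L": "scroll left", "R": "scroll right",
    "B": "back", "H": "home", "W": "wait", "F": "done", "I": "impossible",
}
# Actions that must be followed by an element index.
INDEXED = {"T": "tap", "P": "long-press"}


def is_valid(action: str) -> bool:
    """Check an output against the action grammar (hypothetical helper)."""
    if action in SINGLE:
        return True
    if re.fullmatch(r"[TP]\d+", action):
        return True
    # 'K' needs a space followed by at least one character of text.
    return action.startswith("K ") and len(action) > 2


def correction_blurb(action: str) -> Optional[str]:
    """Return retry text for a malformed action, or None if it is valid."""
    if is_valid(action):
        return None
    letter, rest = action[:1], action[1:].strip()
    hints = []
    if letter in SINGLE:
        # e.g. 'B4': either the bare action or a tap on that index was meant.
        hints.append(f"'{letter}' ({SINGLE[letter]})")
        if rest.isdigit():
            hints.append(f"'T{rest}' (tap element {rest})")
    elif letter in INDEXED and rest and not rest.isdigit():
        # e.g. 'T some text': assumed here to be a mistyped 'K' action.
        hints.append(f"'K {rest}' (type text)")
    question = f" Did you mean {' or '.join(hints)}?" if hints else ""
    return f"'{action}' is not valid.{question} Try again:"
```

In a retry loop, the returned string would be appended after `Next action: ` in the original prompt before generating again; a `None` result means the action can be executed as is.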

---

## Performance (step 656)

Evaluated on the AndroidControl dataset (accessibility tree format, single-step predictions):

| Metric | Last 20 steps | Last 50 steps | All (102 steps) |
|---|---|---|---|
| Overall accuracy | 95.0% | 92.0% | 88.2% |
| Element index accuracy | 93.3% | 88.6% | 84.6% |

**Action type breakdown (last 20 steps):**

| Action | Accuracy |
|---|---|
| tap (T) | 93% |
| scroll (U/D/L/R) | 100% |
| back (B) | 100% |
| type (K) | 100% |
| wait (W) | 100% |

Remaining errors are primarily off-by-one element indices on tap targets, a known SFT ceiling that the planned RL stage (see Roadmap) is meant to address.

---

## Training Process

**Base model:** `google/gemma-4-E4B-it` (4B MoE, 4-bit quantized during training via bitsandbytes)

**LoRA config:**

- `r=64`, `alpha=32`, `dropout=0.05`
- Target modules: all linear layers in the transformer

**Training data:** AndroidControl dataset (accessibility tree variant), ~20 shards from GCS. Each sample is a single (screenshot, a11y tree, goal, history) → action step from a real Android interaction trajectory.

**Key training decisions:**

- No label smoothing: removed after identifying that it softened action type gradients
- `accum_steps=1`: every sample is its own gradient update (maximum signal density)
- `lr=5e-5`, cosine schedule
- Grammar-constrained loss: inference-time cap per action type (T/P: 5 tokens max, single-token actions: 1 token). Wrong action type predictions lose access to downstream element-index reward
- Type token weights: tap=4.0, long_press=4.0, type=8.0, scrolls=8.0 (upweighted to prevent collapse)
- Sample weights: rare actions (back/home/wait/done/impossible) upweighted 3× to prevent tap dominance
- Rolling window diversity quota (window=20): ensures each action type appears proportionally in recent batches

**Training infrastructure:**

- Single RTX 3060 12GB
- ~100s/step (full image + tree encoding + gradient update)
- Milestone checkpoints every ~100 steps via sentinel file

**To replicate from scratch:**

1. Download the AndroidControl dataset (GCS, 20 shards, ~47GB)
2. Preprocess with `scripts/preprocess_a11y.py` to extract accessibility trees
3. Train: `py -3.11 -u scripts/train_autoregressive.py --out checkpoints/autoreg/current`
4. Resume: `py -3.11 -u scripts/train_autoregressive.py --resume checkpoints/autoreg/current/step_XXXXXXX --out checkpoints/autoreg/current`
5. Monitor: tail the log file for HIT/miss lines; ntfy.sh push notifications are sent every 5 steps (topic: Cerebellum-Training)

---

## Loading the Adapter

```python
import torch
from peft import PeftModel
from transformers import AutoProcessor, Gemma4ForConditionalGeneration

# Load the base model in bfloat16 and let device_map place it on the available devices.
base = Gemma4ForConditionalGeneration.from_pretrained(
    "google/gemma-4-E4B-it",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# Attach the LoRA adapter weights on top of the base model.
model = PeftModel.from_pretrained(base, "dmitchelljackson/cerebellum-e4b-lora")
# The processor handles chat templating plus image and text preprocessing.
processor = AutoProcessor.from_pretrained("dmitchelljackson/cerebellum-e4b-lora")
model.eval()  # inference mode: disables dropout
```

---

## Roadmap

- [x] SFT on AndroidControl (~88-95% single-step accuracy)
- [x] Inference-time error recovery (format validator + correction blurb)
- [ ] RL fine-tuning (GRPO) on AndroidWorld tasks for multi-step accuracy and semantic recovery
- [ ] Error recovery fine-tuning on collected failure cases