---
language: en
license: apache-2.0
base_model: google/gemma-4-E4B-it
tags:
- android
- ui-automation
- accessibility
- lora
- peft
---
# Cerebellum: Android UI Action Predictor

LoRA adapter on top of `google/gemma-4-E4B-it` that predicts the next Android UI action given a screenshot and accessibility tree.

**Architecture:** The LLM (or orchestrating agent) issues high-level intent. Cerebellum executes it locally by grounding intent to a specific UI element and action, without screenshot round-trips to a remote model.

---

## What It Does

Given a task goal, the current screen (screenshot + accessibility tree), and optional action history, the model outputs a single compact action code indicating what to do next.

---

## Input Format

The model uses a chat-style prompt (Gemma4 format). The user turn is structured as:

```
Task: {goal}
Step 1 (past): <|image|> -> {action_text}
Step 2 (past): <|image|> -> {action_text}
...
Current screen: <|image|>
{compressed_accessibility_tree}
[n zone]=tap-target(top-to-bottom left-to-right) zone=tl/tc/tr/ml/mc/mr/bl/bc/br ed=text-input sr=scrollable fc=focused(use 'K your_text' to type here)
Actions: T{n}=tap element n, P{n}=long-press element n, K {text}=type text(space required), U/D/L/R=scroll(single token), B=back, H=home, W=wait, F=done, I=impossible
Next action:
```

**Inputs:**
- `goal`: natural language task description (e.g. "Open the settings app and enable dark mode")
- `history`: up to 4 past (screenshot, action) pairs; can be empty
- `current screenshot`: PIL image of the current screen, resized to 896px on the long edge
- `compressed_accessibility_tree`: compact text representation of the UI element tree (see below)
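
The text part of the user turn can be assembled with a few lines of string formatting. The sketch below is illustrative only: `build_user_turn` is a hypothetical helper, and it assumes that when more than four past steps exist only the most recent four are kept and renumbered from 1. The legend and action lines are copied from the template; screenshots would be passed to the processor separately, one per `<|image|>` placeholder.

```python
from typing import Sequence

# Legend and action list copied verbatim from the prompt template.
LEGEND = (
    "[n zone]=tap-target(top-to-bottom left-to-right) "
    "zone=tl/tc/tr/ml/mc/mr/bl/bc/br ed=text-input sr=scrollable "
    "fc=focused(use 'K your_text' to type here)"
)
ACTIONS = (
    "Actions: T{n}=tap element n, P{n}=long-press element n, "
    "K {text}=type text(space required), U/D/L/R=scroll(single token), "
    "B=back, H=home, W=wait, F=done, I=impossible"
)
IMAGE_TOKEN = "<|image|>"
MAX_HISTORY = 4  # the card states "up to 4 past (screenshot, action) pairs"


def build_user_turn(goal: str, past_actions: Sequence[str], tree: str) -> str:
    """Assemble the user-turn text (hypothetical helper, not the author's code).

    past_actions holds the action text of earlier steps, oldest first.
    """
    # Assumption: keep only the most recent MAX_HISTORY steps.
    recent = list(past_actions)[-MAX_HISTORY:]
    lines = [f"Task: {goal}"]
    for step, action in enumerate(recent, start=1):
        lines.append(f"Step {step} (past): {IMAGE_TOKEN} -> {action}")
    lines.append(f"Current screen: {IMAGE_TOKEN}")
    lines.append(tree)
    lines.append(LEGEND)
    lines.append(ACTIONS)
    lines.append("Next action:")
    return "\n".join(lines)
```

The number of `<|image|>` placeholders in the result is always one more than the number of history steps kept, which gives a quick sanity check before pairing the text with images.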
### Accessibility Tree Format

Each element is rendered on one line:

```
[0 bt tl] Settings
[1 ed mc fc] Search...
[2 bt tr sr] More options
```
Fields per element:
- `[n]`: element index (used in action codes)
- type: `bt`=Button, `ed`=EditText, `tx`=TextView, `im`=ImageView, `ck`=CheckBox, `sw`=Switch, `rd`=RadioButton, `sp`=Spinner, `sc`=ScrollView, `ls`=ListView/RecyclerView, `bar`=Toolbar, `tab`=TabLayout, `dw`=DrawerLayout, `vw`=other
- zone: screen position of element center, given as row (`t`=top, `m`=mid, `b`=bottom) + col (`l`=left, `c`=center, `r`=right), e.g. `tl`=top-left, `mc`=mid-center
- `fc`: element has keyboard focus (K action types here)
- `ed`: element is editable (text input)
- `sr`: element is scrollable
- `hd`: element supports long-press
- `ds`: element is disabled
- `ck`/`uc`: checkbox checked/unchecked
- `sl`: element is selected
- `pw`: password field
- `...above (N nodes, scroll up)` / `...below (N nodes, scroll down)`: off-screen content indicators
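
For tooling that needs to map a predicted index back to an element, a tree line can be parsed into its fields. This is a minimal sketch, not the project's parser: `Element` and `parse_element` are hypothetical names, and it assumes the bracket holds an optional index, then the type tag, then zone and flags in any order, with non-indexed nodes simply omitting the number.

```python
import re
from dataclasses import dataclass, field
from typing import List, Optional

ZONES = {"tl", "tc", "tr", "ml", "mc", "mr", "bl", "bc", "br"}
LINE_RE = re.compile(r"^\[(?P<body>[^\]]*)\]\s*(?P<label>.*)$")


@dataclass
class Element:
    index: Optional[int]  # None for nodes rendered without an index
    kind: str             # type tag such as "bt" or "ed"
    zone: Optional[str]   # one of ZONES, if present
    flags: List[str] = field(default_factory=list)
    label: str = ""


def parse_element(line: str) -> Optional[Element]:
    """Parse one tree line; returns None for scroll indicators and other text."""
    match = LINE_RE.match(line.strip())
    if match is None:
        return None
    tokens = match.group("body").split()
    index = None
    if tokens and tokens[0].isdigit():
        index = int(tokens.pop(0))
    if not tokens:
        return None
    kind = tokens.pop(0)
    zone = None
    flags: List[str] = []
    for tok in tokens:
        if zone is None and tok in ZONES:
            zone = tok
        else:
            # Tolerate "fc=focused"-style tokens by keeping the part before "=".
            flags.append(tok.split("=", 1)[0])
    return Element(index, kind, zone, flags, match.group("label"))
```

Because the zone is recognised by membership in `ZONES` and not by position, the sketch tolerates lines where a flag appears before the zone.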
### Tree Compression Rules

The raw Android accessibility tree is compressed before being passed to the model:

1. **Node filtering**: nodes without text, content description, resource ID, clickability, or scrollability are collapsed (their children are promoted up)
2. **Off-screen filtering**: nodes fully outside the screen bounds (`x2<=0`, `y1>=screen_height`, etc.) are excluded and replaced with `...above (N)` / `...below (N)` scroll indicators
3. **Sibling deduplication**: identical sibling subtrees are rendered only once (handles repeated list items)
4. **Multi-window deduplication**: Android can return multiple root windows; duplicate root blocks are dropped. Roots whose tappable element IDs are fully covered by a larger root are also dropped (handles split-screen / overlay artifacts)
5. **Largest-first ordering**: when multiple roots exist, the most complete window (most lines) is rendered first
6. **Element indexing**: only `clickable=true AND enabled=true` nodes get a numeric index `[n]`. Non-clickable nodes are rendered without an index. Index order is top-to-bottom, left-to-right by element position
7. **Type abbreviation**: class names are mapped to short tags (e.g. `android.widget.Button` → `bt`)
8. **Zone encoding**: element center is bucketed into a 3×3 grid zone string (`tl/tc/tr/ml/mc/mr/bl/bc/br`)
9. **Label selection**: `text` is preferred; falls back to `content_desc`; falls back to `resource_id` (last component after `/`)
10. **Hard truncation**: tree is truncated at 4000 characters before tokenization to prevent OOM on dense screens
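
Rules 8 to 10 are simple enough to sketch directly. The helpers below are hypothetical and make two assumptions the card does not state: the 3×3 grid splits the screen into equal thirds, and centers that fall outside the screen are clamped to the nearest edge cell.

```python
from typing import Optional

ROWS = "tmb"  # top, mid, bottom
COLS = "lcr"  # left, center, right
MAX_TREE_CHARS = 4000  # hard truncation limit from rule 10


def zone_for(x1: float, y1: float, x2: float, y2: float,
             screen_w: float, screen_h: float) -> str:
    """Bucket the element center into a 3x3 grid (assumes equal thirds)."""
    cx = (x1 + x2) / 2
    cy = (y1 + y2) / 2
    col = min(2, max(0, int(3 * cx / screen_w)))
    row = min(2, max(0, int(3 * cy / screen_h)))
    return ROWS[row] + COLS[col]


def pick_label(text: Optional[str], content_desc: Optional[str],
               resource_id: Optional[str]) -> str:
    """Prefer text, then content description, then the resource ID tail."""
    if text:
        return text
    if content_desc:
        return content_desc
    if resource_id:
        return resource_id.rsplit("/", 1)[-1]
    return ""


def truncate_tree(tree: str, limit: int = MAX_TREE_CHARS) -> str:
    """Cut the rendered tree to at most `limit` characters."""
    return tree[:limit]
```

For a 1080×2400 screen, an element spanning (440, 1100) to (640, 1300) has its center at (540, 1200) and lands in zone `mc` under these assumptions.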

---

## Output Format

A single action code, produced by one greedy-decoded generation call:

| Code | Action | Example |
|---|---|---|
| `T{n}` | Tap element n | `T7` |
| `P{n}` | Long-press element n | `P3` |
| `K {text}` | Type text into focused field | `K hello world` |
| `U` | Scroll up | `U` |
| `D` | Scroll down | `D` |
| `L` | Scroll left | `L` |
| `R` | Scroll right | `R` |
| `B` | System back | `B` |
| `H` | Home button | `H` |
| `W` | Wait (screen loading) | `W` |
| `F` | Done (task complete) | `F` |
| `I` | Impossible (task cannot complete) | `I` |

Single-token actions (U/D/L/R/B/H/W/F/I) self-terminate: no EOS token follows. T/P generate up to 5 tokens (letter + digits + EOS). K generates until EOS.
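
On the consuming side, the action grammar in the table can be decoded with a small parser. The following is a sketch under the assumption that the decoded string contains only the action code plus optional surrounding whitespace; `parse_action` is a hypothetical name, not part of the released code.

```python
import re
from typing import Optional, Tuple, Union

SINGLE_TOKEN = set("UDLRBHWFI")
INDEXED_RE = re.compile(r"([TP])(\d+)")

Action = Tuple[str, Union[int, str, None]]


def parse_action(raw: str) -> Optional[Action]:
    """Return (letter, argument), or None if the code is malformed.

    The argument is an int for T/P, the text for K, and None otherwise.
    """
    code = raw.strip()
    if code in SINGLE_TOKEN:
        return (code, None)
    indexed = INDEXED_RE.fullmatch(code)
    if indexed:
        return (indexed.group(1), int(indexed.group(2)))
    # K requires a space and at least one character of text.
    if code.startswith("K ") and len(code) > 2:
        return ("K", code[2:])
    return None
```

Returning `None` for anything outside the grammar keeps the caller's dispatch logic separate from format checking.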

---

## Inference-Time Error Recovery

The model occasionally produces malformed outputs (action letter fused with wrong content, e.g. `B4`, `W3`, `T some text`). A lightweight validator detects these and retries with a disambiguating correction blurb appended to the prompt:

```
Next action:
'B4' is not valid. Did you mean 'B' (back) or 'T4' (tap element 4)? Try again:
```

This zero-shot correction resolves the majority of format errors without additional training.
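
A validator of this kind might look like the sketch below. It reproduces the blurb wording shown above for the fused letter-plus-digits case; the fallback wording for other malformed outputs and the names `correction_blurb` and `retry_prompt` are assumptions, since the card only shows the `B4` example.

```python
import re
from typing import Optional

VALID_RE = re.compile(r"(?:[TP]\d+|K .+|[UDLRBHWFI])")
FUSED_RE = re.compile(r"([UDLRBHWFI])(\d+)")
NAMES = {
    "U": "scroll up", "D": "scroll down", "L": "scroll left",
    "R": "scroll right", "B": "back", "H": "home",
    "W": "wait", "F": "done", "I": "impossible",
}


def correction_blurb(output: str) -> Optional[str]:
    """Return a correction message for a malformed output, or None if valid."""
    out = output.strip()
    if VALID_RE.fullmatch(out):
        return None
    fused = FUSED_RE.fullmatch(out)
    if fused:
        letter, n = fused.groups()
        hint = (f"Did you mean '{letter}' ({NAMES[letter]}) "
                f"or 'T{n}' (tap element {n})?")
    else:
        # Assumed fallback wording; the card shows only the fused case.
        hint = "Use exactly one action code from the Actions list."
    return f"'{out}' is not valid. {hint} Try again:"


def retry_prompt(prompt: str, output: str) -> str:
    """Append the correction blurb to the prompt when the output is malformed."""
    blurb = correction_blurb(output)
    return prompt if blurb is None else f"{prompt}\n{blurb}"
```

A caller would typically bound the number of retries, since nothing guarantees the second attempt is well-formed.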

---

## Performance (step 656)

Evaluated on the AndroidControl dataset (accessibility tree format, single-step predictions):

| Metric | Last 20 steps | Last 50 steps | All (102 steps) |
|---|---|---|---|
| Overall accuracy | 95.0% | 92.0% | 88.2% |
| Element index accuracy | 93.3% | 88.6% | 84.6% |

**Action type breakdown (last 20 steps):**

| Action | Accuracy |
|---|---|
| tap (T) | 93% |
| scroll (U/D/L/R) | 100% |
| back (B) | 100% |
| type (K) | 100% |
| wait (W) | 100% |

Remaining errors are primarily off-by-one element indices on tap targets, a known ceiling of SFT that the planned RL fine-tuning is intended to address.

---

## Training Process

**Base model:** `google/gemma-4-E4B-it` (4B MoE, 4-bit quantized during training via bitsandbytes)

**LoRA config:**
- `r=64`, `alpha=32`, `dropout=0.05`
- Target modules: all linear layers in the transformer

**Training data:** AndroidControl dataset (accessibility tree variant), ~20 shards from GCS. Each sample is a single (screenshot, a11y tree, goal, history) → action step from a real Android interaction trajectory.

**Key training decisions:**
- No label smoothing: it was removed after it was found to soften action-type gradients
- `accum_steps=1`: every sample is its own gradient update (maximum signal density)
- `lr=5e-5`, cosine schedule
- Grammar-constrained loss: inference-time cap per action type (T/P: 5 tokens max, single-token actions: 1 token). Wrong action type predictions lose access to downstream element-index reward
- Type token weights: tap=4.0, long_press=4.0, type=8.0, scrolls=8.0 (upweighted to prevent collapse)
- Sample weights: rare actions (back/home/wait/done/impossible) upweighted 3× to prevent tap dominance
- Rolling window diversity quota (window=20): ensures each action type appears proportionally in recent batches
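
The two weighting schemes in the list above can be illustrated with plain Python. This is a sketch only: the weight values come from the list, but how they are combined is an assumption (a weighted mean over token losses scaled by the sample weight), a default of 1.0 is assumed for anything not listed, and `weighted_loss` is a hypothetical name that does not come from the training script.

```python
from typing import Sequence

# Weights on the action-type token, taken from the list above.
TYPE_TOKEN_WEIGHT = {
    "T": 4.0, "P": 4.0, "K": 8.0,
    "U": 8.0, "D": 8.0, "L": 8.0, "R": 8.0,
}
RARE_ACTIONS = set("BHWFI")  # back/home/wait/done/impossible
RARE_SAMPLE_WEIGHT = 3.0


def type_token_weight(action: str) -> float:
    """Weight for the first (action-type) token; 1.0 if unlisted (assumed)."""
    return TYPE_TOKEN_WEIGHT.get(action[0], 1.0)


def sample_weight(action: str) -> float:
    """Per-sample weight: rare actions are upweighted 3x."""
    return RARE_SAMPLE_WEIGHT if action[0] in RARE_ACTIONS else 1.0


def weighted_loss(per_token_losses: Sequence[float], action: str) -> float:
    """Combine token losses into one scalar (assumed combination rule).

    per_token_losses[0] is the loss on the action-type token; the remaining
    entries are the losses on the following tokens (digits, text, EOS).
    """
    weights = [type_token_weight(action)] + [1.0] * (len(per_token_losses) - 1)
    total = sum(w * l for w, l in zip(weights, per_token_losses))
    return sample_weight(action) * total / sum(weights)
```

Under this combination rule, a mistake on the type token of a tap contributes four times as much as a mistake on one of its index digits.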

**Training infrastructure:**
- Single RTX 3060 12GB
- ~100s/step (full image + tree encoding + gradient update)
- Milestone checkpoints every ~100 steps via sentinel file

**To replicate from scratch:**
1. Download AndroidControl dataset (GCS, 20 shards, ~47GB)
2. Preprocess with `scripts/preprocess_a11y.py` to extract accessibility trees
3. Train: `py -3.11 -u scripts/train_autoregressive.py --out checkpoints/autoreg/current`
4. Resume: `py -3.11 -u scripts/train_autoregressive.py --resume checkpoints/autoreg/current/step_XXXXXXX --out checkpoints/autoreg/current`
5. Monitor: tail the log file for HIT/miss lines; ntfy.sh push notifications are sent every 5 steps (topic: Cerebellum-Training)

---

## Loading the Adapter

```python
import torch
from peft import PeftModel
from transformers import AutoProcessor, Gemma4ForConditionalGeneration

# Load the base model in bfloat16; device_map="auto" places it on the available device(s).
base = Gemma4ForConditionalGeneration.from_pretrained(
    "google/gemma-4-E4B-it",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# Attach the LoRA adapter weights on top of the base model.
model = PeftModel.from_pretrained(base, "dmitchelljackson/cerebellum-e4b-lora")
# The processor handles text tokenization and image preprocessing.
processor = AutoProcessor.from_pretrained("dmitchelljackson/cerebellum-e4b-lora")
model.eval()  # inference mode (disables dropout)
```

---

## Roadmap

- [x] SFT on AndroidControl (~88-95% single-step accuracy)
- [x] Inference-time error recovery (format validator + correction blurb)
- [ ] RL fine-tuning (GRPO) on AndroidWorld tasks for multi-step accuracy and semantic recovery
- [ ] Error recovery fine-tuning on collected failure cases