Left vs. Right Alignment: A case study from porting LeRobot PI05's subtask prediction from JAX to PyTorch

Community Article
Published July 1, 2026

TL;DR: When you're working with transformer-based models, a single misalignment in your padding, attention masks, or position IDs can silently degrade performance without ever throwing an error. In my case, it was the mismatch between left-alignment (Hugging Face's default) and right-alignment (used by Physical Intelligence in JAX). Here's how I learned that the hard way while implementing subtask prediction for PI05 — and why this difference actually matters, even if (to be honest) it's hard to measure exactly how much.


The Setup: What Is PI05?

PI05 is a Vision-Language-Action (VLA) model from Physical Intelligence. It takes images and a text prompt, processes them through a PaliGemma vision-language model, and then uses a separate Gemma expert to generate robot actions via flow matching.

The architecture looks like this:

graph LR
    I[Images] --> SE[SigLIP Encoder]
    T[Text Prompt] --> TE[Token Embedding]
    SE --> CAT[Concatenate Embeddings]
    TE --> CAT
    CAT --> VLM[PaliGemma LLM]
    VLM --> EXP[Gemma Expert]
    EXP --> A[Actions]

How I Got Here: Implementing the Subtask Layer

PI05 is the upgraded successor to the PI0 model, and one of its most important new features is subtask prediction — an autoregressive decoding step where the VLM generates a short text description of the current subtask (e.g., "reach for the cup") before the action head kicks in.

This subtask conditions the action generation and, according to the paper, dramatically improves performance on long-horizon tasks. (I'll be honest: on a low budget, it's hard to verify that claim yourself, but the idea is compelling.)

The subtask prediction reuses the exact same embed_prefix pipeline — images + text → embeddings → pad masks → attention masks → 2D attention mask — and then runs autoregressive token generation from the VLM's language head.

Now, here's where things get personal. The LeRobot repo has a PyTorch port of PI05, but it's essentially a port of PI0 — subtask prediction is not included. If you want that feature, you have to implement it yourself.

So I did. I wrote the code, trained the model on a decent labeled dataset, and ran inference. The result? Nonsense words. Garbled tokens, gibberish subtasks — completely unusable.

At that point you start digging. There are many possible sources of failure: bad data, wrong tokenizer config, incorrect attention masking, KV cache mismanagement. But one of the most challenging — and, in retrospect, most educational — was the alignment convention.

I discovered that PI05 was trained with right-alignment in JAX. The LeRobot PyTorch port, being built on Hugging Face's ecosystem, uses left-alignment. I had a hunch this could be a big problem — and it was.

(I can't share the full repo since it's private, but the key implementation file is linked at the end of this post.)

Two Worlds, Two Alignments

Here's the core tension:

                     Physical Intelligence (JAX)             Hugging Face (PyTorch)
Padding alignment    Right-aligned (padding on the left)     Left-aligned (padding on the right)
Position IDs         Derived from right-aligned cumsum       Derived from left-aligned cumsum
Attention mask       Built assuming left-padded tokens       Built assuming right-padded tokens

In JAX, when you have a variable-length sequence like [x₁, x₂, x₃] that needs to fit into a fixed-length tensor of size T=6, it gets padded on the left:

JAX (right-aligned):  [0, 0, 0, x₁, x₂, x₃]

In PyTorch / Hugging Face, the same sequence gets padded on the right:

PyTorch (left-aligned): [x₁, x₂, x₃, 0, 0, 0]
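To make the two layouts concrete, here is a minimal plain-Python sketch (no PyTorch or JAX, so the index arithmetic stays visible). pad_to is a hypothetical helper, not a LeRobot or openpi function, and it assumes pad ID 0 as in the examples above. The point to notice is that the pad mask moves together with the tokens.

```python
def pad_to(tokens, T, align, pad_id=0):
    """Hypothetical helper: place `tokens` in a length-T buffer, return (buffer, pad_mask).

    align="left":  real tokens first, padding at the end.
    align="right": padding first, real tokens at the end.
    """
    if len(tokens) > T:
        raise ValueError("sequence longer than buffer")
    n_pad = T - len(tokens)
    if align == "left":
        return tokens + [pad_id] * n_pad, [1] * len(tokens) + [0] * n_pad
    if align == "right":
        return [pad_id] * n_pad + tokens, [0] * n_pad + [1] * len(tokens)
    raise ValueError(f"unknown align: {align!r}")


x = [11, 12, 13]  # stand-ins for x1, x2, x3
print(pad_to(x, 6, "right"))  # ([0, 0, 0, 11, 12, 13], [0, 0, 0, 1, 1, 1])
print(pad_to(x, 6, "left"))   # ([11, 12, 13, 0, 0, 0], [1, 1, 1, 0, 0, 0])
```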

This looks like a trivial cosmetic difference. It is not.


Walking Through the Pipeline

Let's trace what actually happens to your tensors step by step. We'll use a batch size of B, a fixed sequence length of T, and embedding dimension D_emb. The final tensor flowing into the transformer always has shape:

$$(B, T, D_{\text{emb}})$$

Step 1: Images → Embeddings

Each image goes through SigLIP and produces num_img_embs embedding vectors. For K camera views, we concatenate them:

I = [I_top | I_left | I_right]

Missing cameras are padded with -1 pixels and get a mask of 0. Present cameras get a mask of 1:

m_img = [1, 1, …, 1, 0, 0, …, 0]
         └── B present ──┘ └── 2B missing ──┘

The image mask is then expanded along the sequence dimension, so that each of a camera's num_img_embs embeddings carries that camera's mask bit:

$$P_{\text{img}} = m_{\text{img}}[:, \text{None}].\text{expand}(B, \text{num\_img\_embs})$$
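A plain-Python sketch of this masking step, assuming masks are 0/1 lists and using a toy num_img_embs of 2 (real SigLIP outputs are far longer). build_image_pad_mask is a hypothetical name; it mirrors the expand-then-concatenate step for each camera in the order [top | left | right].

```python
def build_image_pad_mask(camera_masks, num_img_embs):
    """camera_masks: one list of B ints per camera (1 = present, 0 = missing).

    Returns B rows of length K * num_img_embs: every embedding produced from a
    camera inherits that camera's presence bit, cameras concatenated in order.
    """
    B = len(camera_masks[0])
    rows = []
    for b in range(B):
        row = []
        for cam in camera_masks:
            # Per-example equivalent of m_img[:, None].expand(B, num_img_embs)
            row.extend([cam[b]] * num_img_embs)
        rows.append(row)
    return rows


# Batch of 2: top present for both, left present only for the first, right missing.
print(build_image_pad_mask([[1, 1], [1, 0], [0, 0]], num_img_embs=2))
# [[1, 1, 1, 1, 0, 0], [1, 1, 0, 0, 0, 0]]
```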

Step 2: Text → Tokens

The tokenizer converts the prompt into token IDs and produces a boolean mask:

tok   = [x₁, x₂, …, xₙ, 0, 0, 0]
m_tok = [ 1,  1, …,  1, 0, 0, 0]

Step 3: Concatenate Everything

Now we concatenate image embeddings and text embeddings along the sequence dimension, and do the same for their masks:

embs      = [img_embs | txt_embs]          shape: (B, T, D_emb)
pad_mask  = [img_mask | txt_mask]          shape: (B, T)
att_mask  = [0, 0, …, 0]                   shape: (B, T)

The att_mask here is the prefix attention mask: 0 means "this token is part of the bidirectional prefix" and 1 means "this token starts a new causal block." For the prefix (images + prompt), every entry is 0 because all tokens can attend to all others.
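One consequence worth seeing explicitly: when a camera is missing, the concatenated pad mask has a block of zeros in the middle, so the prefix is neither purely left- nor right-aligned. The sketch below (plain Python, toy sizes, hypothetical helper name) builds the prefix masks and applies the cumsum - 1 position rule from Step 5; the prompt tokens continue right after the last real image token.

```python
from itertools import accumulate


def build_prefix_masks(img_mask, txt_mask):
    """Concatenate image and text pad masks for one example.

    The prefix att_mask is all zeros: every prefix token sits in the same
    bidirectional block.
    """
    pad_mask = img_mask + txt_mask
    att_mask = [0] * len(pad_mask)
    return pad_mask, att_mask


img_mask = [1, 1, 0, 0, 0, 0]  # top camera present, left/right missing (toy num_img_embs=2)
txt_mask = [1, 1, 1, 0, 0]     # three prompt tokens, left-aligned with two pads
pad_mask, att_mask = build_prefix_masks(img_mask, txt_mask)
position_ids = [c - 1 for c in accumulate(pad_mask)]  # cumsum(pad_mask) - 1
print(pad_mask)      # [1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0]
print(position_ids)  # [0, 1, 1, 1, 1, 1, 2, 3, 4, 4, 4]
```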

Step 4: The Critical Function — make_att_2d_masks

This is where alignment really matters. The function that builds the 2D attention matrix works like this:

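import torch

def make_att_2d_masks(pad_masks, att_masks):
    # pad_masks: (B, T) bool, True for real tokens; att_masks: (B, T), 1 starts a new causal block
    cumsum = torch.cumsum(att_masks, dim=1)                 # (B, T) block index of each token
    att_2d = cumsum[:, None, :] <= cumsum[:, :, None]       # (B, T, T) query i sees key j if block(j) <= block(i)
    pad_2d = pad_masks[:, None, :] * pad_masks[:, :, None]  # (B, T, T) both query and key must be real
    return att_2d & pad_2d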

The cumsum on att_masks assigns each token to a causal block. Tokens with the same cumulative sum value can attend to each other. Across blocks, attention is causal: later blocks can see earlier ones, but not vice versa.

For a pure prefix, att_mask = [0, 0, 0, 0, 0, 0], so cumsum = [0, 0, 0, 0, 0, 0] and all tokens attend to all others — perfect bidirectional attention.

But if you introduce causal tokens (like during autoregressive decoding), the cumulative sum pattern changes, and suddenly where the real tokens sit in the tensor matters a lot.
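Here is a single-example, plain-Python version of make_att_2d_masks (the same logic as the PyTorch function above, without the batch dimension) with a tiny text renderer, where x means "may attend". The first case has three bidirectional prefix tokens followed by two causal tokens; the second shows that a padding slot in the middle is masked out cleanly. The mask itself copes with any layout; trouble starts when other code slices it with offsets that assume one particular layout.

```python
from itertools import accumulate


def make_att_2d_masks(pad_mask, att_mask):
    """Single-example, plain-Python version of make_att_2d_masks.

    Entry [i][j] is True when query token i may attend to key token j.
    """
    cumsum = list(accumulate(att_mask))  # causal block index of each token
    T = len(pad_mask)
    return [
        [cumsum[j] <= cumsum[i] and bool(pad_mask[i]) and bool(pad_mask[j]) for j in range(T)]
        for i in range(T)
    ]


def show(mask):
    for row in mask:
        print("".join("x" if allowed else "." for allowed in row))


# Three bidirectional prefix tokens, then two causal tokens.
show(make_att_2d_masks([1, 1, 1, 1, 1], [0, 0, 0, 1, 1]))
# xxx..
# xxx..
# xxx..
# xxxx.
# xxxxx

# A padding slot (index 2) between the prompt and one causal token.
show(make_att_2d_masks([1, 1, 0, 1], [0, 0, 0, 1]))
# xx..
# xx..
# ....
# xx.x
```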

Step 5: Position IDs

Position IDs are computed as:

position_ids = torch.cumsum(pad_masks, dim=1) - 1

With right-alignment (JAX):

pad_mask  = [0, 0, 0, 1, 1, 1]
cumsum    = [0, 0, 0, 1, 2, 3]
pos_ids   = [-1, -1, -1, 0, 1, 2]

With left-alignment (PyTorch):

pad_mask  = [1, 1, 1, 0, 0, 0]
cumsum    = [1, 2, 3, 3, 3, 3]
pos_ids   = [0, 1, 2, 2, 2, 2]

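Notice what actually differs. The real tokens get positions 0, 1, 2 in both layouts; what changes is the value written into the padding slots (-1 versus 2) and, more importantly, where the real tokens physically sit in the tensor. As long as the padding is masked out perfectly, the padding positions are harmless. But any code written for one layout (offset slicing, KV-cache indexing, appending the next token at a fixed index) will read the wrong slots under the other, and you've silently corrupted the positional signal.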
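It helps to separate which values actually change between the two layouts. This sketch applies the same cumsum - 1 rule to both pad masks from the example above and prints the positions of the real tokens, plus the position the next generated token would get (the count of real tokens). It uses plain Python rather than torch.cumsum, but the arithmetic is identical.

```python
from itertools import accumulate


def position_ids(pad_mask):
    # Same rule as Step 5: cumsum(pad_mask) - 1
    return [c - 1 for c in accumulate(pad_mask)]


for name, mask in (("right", [0, 0, 0, 1, 1, 1]), ("left", [1, 1, 1, 0, 0, 0])):
    pos = position_ids(mask)
    real = [p for p, m in zip(pos, mask) if m]  # positions of the real tokens only
    print(name, pos, "real tokens:", real, "next position:", sum(mask))
# right [-1, -1, -1, 0, 1, 2] real tokens: [0, 1, 2] next position: 3
# left [0, 1, 2, 2, 2, 2] real tokens: [0, 1, 2] next position: 3
```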


Where It Bites: The Subtask Decoding Loop

Here's the concrete scenario. During subtask prediction, you:

  1. Run the prefix (images + prompt) through the VLM to get the first token embedding
  2. Sample a token autoregressively
  3. Feed that token back in, with an extended attention mask and an incremented position ID
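A toy version of step 3, in plain Python: the hypothetical append_token helper appends the sampled token at the end of a fixed-length prefix buffer and derives its position from the mask. The position comes out the same in both layouts, but only the right-aligned buffer keeps the prompt and the generated token contiguous; the left-aligned one leaves a run of padding between them.

```python
def append_token(buffer, pad_mask, token):
    """Hypothetical decoding step: append a token at the end of a prefix buffer.

    The new position is derived from the mask (cumsum(pad_mask) - 1 at the new slot).
    """
    new_mask = pad_mask + [1]
    return buffer + [token], new_mask, sum(new_mask) - 1


def is_contiguous(pad_mask):
    # True when all real tokens form one unbroken run
    idx = [i for i, m in enumerate(pad_mask) if m]
    return idx == list(range(idx[0], idx[-1] + 1))


layouts = {
    "right": ([0, 0, 0, 11, 12, 13], [0, 0, 0, 1, 1, 1]),
    "left": ([11, 12, 13, 0, 0, 0], [1, 1, 1, 0, 0, 0]),
}
for name, (buf, mask) in layouts.items():
    buf, mask, pos = append_token(buf, mask, 99)
    print(name, buf, "position:", pos, "contiguous:", is_contiguous(mask))
# right [0, 0, 0, 11, 12, 13, 99] position: 3 contiguous: True
# left [11, 12, 13, 0, 0, 0, 99] position: 3 contiguous: False
```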

In the JAX implementation (from Physical Intelligence's openpi repo), there's an explicit re-alignment step:

# left to right align all input token sequences
prefix_token_embeddings, prefix_mask, prefix_attn_mask = left_to_right_align(
    prefix_token_embeddings, prefix_mask, prefix_attn_mask
)

This function shifts all the real tokens to the right end of the buffer, moving the padding to the front, and shifts the masks along with them. Why? Because the autoregressive decoding loop builds its attention mask by slicing from prefix_start, and prefix_start depends on knowing exactly where the real tokens begin:

prefill_size = prefix_token_embeddings.shape[1]
prefill_len = jnp.sum(prefix_mask, axis=-1)
prefix_start = prefill_size - prefill_len  # right-aligned: where tokens start

If you port this to PyTorch without handling the alignment, these assumptions no longer hold. In a left-aligned prefix the real tokens start at index 0, yet prefill_size - prefill_len still evaluates to the number of padding slots, and the slicing logic expects the real tokens to run contiguously from prefix_start to the end of the buffer. The mask you construct during the autoregressive loop will be subtly wrong: tokens might attend to padding, or, worse, valid tokens might be masked out from attending to each other.
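The failure and the fix can both be reproduced on a toy buffer. The sketch below is a single-example, plain-Python illustration of the shifting idea behind left_to_right_align, not a port of the openpi function (which works on batched arrays and also shifts the 2D attention mask). prefix_start uses the prefill_size - prefill_len formula from the snippet above.

```python
def left_to_right_align(x, pad_mask):
    """Single-example sketch: shift a left-aligned buffer so it becomes right-aligned.

    seqlen = index of the last real token + 1; rotating by seqlen moves the
    trailing padding to the front while keeping token order.
    """
    seqlen = max(i for i, m in enumerate(pad_mask) if m) + 1

    def rotate(seq):
        return seq[seqlen:] + seq[:seqlen]

    return rotate(x), rotate(pad_mask)


def prefix_start(pad_mask):
    # prefill_size - prefill_len, as in the openpi snippet
    return len(pad_mask) - sum(pad_mask)


left_buf, left_mask = [11, 12, 13, 0, 0, 0], [1, 1, 1, 0, 0, 0]

# Without re-alignment the offset lands in the padding.
start = prefix_start(left_mask)
print(start, left_buf[start:])  # 3 [0, 0, 0]

# After the shift, the same formula points at the first real token.
buf, mask = left_to_right_align(left_buf, left_mask)
start = prefix_start(mask)
print(buf, buf[start:])  # [0, 0, 0, 11, 12, 13] [11, 12, 13]
```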

The model won't crash. It won't even produce NaN. It'll just perform worse, and you'll spend days wondering why your ported model doesn't match the original's results.


The Fix

When porting from a JAX/right-aligned codebase to PyTorch/Hugging Face, you have a few options:

  1. Mirror the original alignment: Keep everything right-aligned in PyTorch. This means padding on the left, computing position IDs accordingly, and being explicit about it in your code. This is the safest approach if you're loading pretrained weights.

  2. Re-align explicitly: Do what the JAX code does and run a left_to_right_align pass before autoregressive decoding. This moves the left-aligned prefix into the right-aligned layout the decoding loop was written for, so prefix_start and the mask slicing work unchanged.

  3. Adapt the mask-building logic: Rewrite the attention mask construction to be alignment-agnostic by computing offsets from the pad mask's sum rather than assuming a particular padding side.
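For option 3, one possible shape of an alignment-agnostic helper is sketched below. Both function names are hypothetical, and the sketch assumes every token already in the KV cache has a pad-mask entry. Instead of slicing from a precomputed prefix_start, the new token's attention row is read straight off the pad mask, and any offset you still need is computed from the mask itself, so the same code works for either padding side.

```python
def first_real_index(pad_mask):
    # Offset computed from the mask itself instead of assuming a padding side
    return next(i for i, m in enumerate(pad_mask) if m)


def decode_step_attention(cache_pad_mask):
    """Attention row for a newly appended token.

    The new token may attend to every real token already in the cache and to
    itself, wherever those tokens sit in the buffer; no prefix_start slicing.
    """
    return [bool(m) for m in cache_pad_mask] + [True]


for mask in ([0, 0, 0, 1, 1, 1], [1, 1, 1, 0, 0, 0]):
    print(first_real_index(mask), decode_step_attention(mask))
# 3 [False, False, False, True, True, True, True]
# 0 [True, True, True, False, False, False, True]
```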

The key insight is: be explicit about which convention you're using and verify it at every step. Add assertions:

# Verify alignment convention
assert pad_mask[:, 0].all(), "Expected left-aligned (real tokens at start)"
# or
assert pad_mask[:, -1].all(), "Expected right-aligned (real tokens at end)"
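One-line assertions inspect a single column, so a row with no padding at all, or a prefix with holes from missing cameras, can slip past them or trip them for the wrong reason. A slightly stricter per-row check, sketched in plain Python with hypothetical helper names, could look like this:

```python
def layout(pad_mask):
    """Classify one row's pad mask as 'left', 'right', 'full', or 'mixed'."""
    n, T = sum(pad_mask), len(pad_mask)
    if n == T:
        return "full"  # no padding at all: consistent with either convention
    if pad_mask == [1] * n + [0] * (T - n):
        return "left"
    if pad_mask == [0] * (T - n) + [1] * n:
        return "right"
    return "mixed"  # e.g. a hole left by a missing camera


def assert_layout(batch_pad_mask, expected):
    for b, row in enumerate(batch_pad_mask):
        kind = layout(row)
        assert kind in (expected, "full"), f"row {b}: expected {expected}-aligned, got {kind}"


assert_layout([[1, 1, 1, 0], [1, 1, 1, 1]], "left")   # passes: the second row is full
assert_layout([[0, 0, 1, 1], [0, 1, 1, 1]], "right")  # passes
print(layout([1, 1, 0, 0, 1, 1, 0]))  # mixed
```

In PI05's concatenated prefix, a missing camera makes a row "mixed" by design, so a check like this is probably most useful on the text segment or on the decoding buffer rather than on the whole prefix.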

Why This Matters Beyond PI05

This isn't just a PI05 problem. Any time you:

  • Load pretrained weights from a JAX model into PyTorch
  • Mix components trained with different alignment conventions
  • Implement custom attention patterns (prefix-LM, block-causal, etc.)
  • Do autoregressive decoding with KV caches and variable-length prefixes

…the alignment of your padding will determine whether your position IDs, attention masks, and KV cache offsets are correct.

The scary part is that alignment bugs are silent. Your loss might look fine during training. Your model might generate plausible-looking outputs. But you're leaving performance on the table — and in robotics, where PI05 is deployed, that could mean the difference between grasping the cup and knocking it over.


Lessons Learned

  1. Trace the shapes. Before trusting any model port, manually trace (B, T, D) through every step with a toy batch of size 2. Print the pad masks, position IDs, and attention masks at each stage.

  2. Alignment is a first-class design decision. Don't treat it as an implementation detail. Document which convention each component expects.

  3. The conventions come from codebases, not from the frameworks. openpi right-aligns its prefix before decoding, as the left_to_right_align call shows, while the LeRobot port keeps it left-aligned, and Hugging Face tokenizers let you pick either side through padding_side. Neither is wrong, but mixing them without adaptation is.

  4. The Hugging Face ecosystem is left-aligned. If you're contributing a model to transformers, left-align your sequences. If you're porting a model that was trained right-aligned, handle the conversion explicitly and document it.

The code where I implemented this is here: PI05 alignment code
