FastViT-JAX (Weights Port)
This repository provides weight porting and inference utilities for using Apple FastViT models in JAX/Flax.
This is not a reimplementation of FastViT, and the code is not concerned with training.
Quick Start
If you just want the weights, you can load a specific model variant (e.g., fastvit_sa12) directly from Python without cloning the repository.
Install dependencies:
```
pip install huggingface_hub flax orbax-checkpoint
```

Load the model:
```python
import sys
from huggingface_hub import snapshot_download

# Download code and specific model weights (skips other checkpoints)
repo_path = snapshot_download(
    repo_id="SilverGrace-26/fastvit-jax-weights",
    allow_patterns=[
        "flax_models/**",                 # the model architecture code
        "weights/orbax/fastvit_sa12/**",  # ONLY the sa12 weights
    ],
)

# Add the downloaded folder to the Python path so we can import the model
sys.path.append(repo_path)

# Now import the code as if it were local
from flax_models.fastvit import FastViT
```
What it provides
- A JAX/Flax forward-pass implementation compatible with Apple’s FastViT
- Utilities to:
- Load official PyTorch FastViT checkpoints
- Apply reparameterization / fusion
- Convert fused weights to JAX/Flax format
- Verify numerical parity between PyTorch and JAX inference
- A self-contained setup for testing and validation
- Ported JAX/Flax weights tested against PyTorch outputs, all within a difference threshold of 1e-4
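The parity check in the last bullet amounts to comparing outputs from both frameworks against the stated threshold. A minimal sketch of that comparison (with placeholder arrays standing in for real PyTorch and JAX logits; `check_parity` is an illustrative helper, not a function from this repository):

```python
import numpy as np

def check_parity(pt_logits, jax_logits, threshold=1e-4):
    """Return True if the max absolute difference between outputs is within threshold."""
    max_diff = float(np.max(np.abs(pt_logits - jax_logits)))
    return max_diff < threshold

# Placeholder arrays standing in for real PyTorch / JAX logits
rng = np.random.default_rng(0)
pt_logits = rng.standard_normal((1, 1000)).astype(np.float32)
jax_logits = pt_logits + 1e-6  # simulate tiny float drift between frameworks

print(check_parity(pt_logits, jax_logits))
```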
Note
The original model was trained on ImageNet-1K, but ImageNette was used for validation here due to its small size. This does not affect the parity testing. Users may use a dataset of their own choice.
Usage
To test the ported weights with the provided code (you will need Git LFS to clone the repo):
Extract the dataset with:

```
tar -xzf imagenette-320-val
```

Use the commands:

```
uv run python inference_test_random.py --model {model_name} --load-orbax
uv run python inference_test_real.py --model {model_name} --dataset-dir /path/to/your/validation_set/if_not_default --load-orbax
```
Vendored Apple files
To make this repository self-contained and easy to test, three files under pytorch_models/ are copied from Apple's official FastViT repository with minimal modifications (import paths only).
These files are required for:
- Loading official FastViT checkpoints
- Applying the original reparameterization (fusion) logic
Each vendored file:
- Retains Apple’s original copyright notice
- Is covered by Apple’s FastViT license (see
LICENSE)
All other files in this repository are original.
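For intuition, the core of the reparameterization step is folding each BatchNorm into its preceding convolution so the pair collapses into a single conv at inference time. The sketch below shows the general BN-folding arithmetic in numpy; it is an illustration of the technique, not the vendored Apple code:

```python
import numpy as np

def fuse_conv_bn(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold a BatchNorm (scale gamma, shift beta, running mean/var) into a conv.

    w: (out, in, kh, kw) conv kernel, b: (out,) conv bias.
    Returns fused (w', b') such that conv'(x) == BN(conv(x)).
    """
    scale = gamma / np.sqrt(var + eps)        # per-output-channel scale
    w_fused = w * scale[:, None, None, None]  # scale each output filter
    b_fused = beta + (b - mean) * scale       # shift absorbs the BN statistics
    return w_fused, b_fused
```

Because the fused conv reproduces the conv-plus-BN output exactly (up to float rounding), the inference-time graph needs fewer ops with no accuracy change.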
Weight porting workflow
At a high level, the process is:
- Load an official PyTorch FastViT checkpoint
- Instantiate the training-time architecture
- Apply reparameterization / fusion to obtain inference-time weights
- Export the fused weights
- Convert the fused weights into JAX/Flax format
- Run inference in both frameworks and compare outputs
This ensures that the JAX model is functionally equivalent to the original PyTorch model.
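The conversion step above largely comes down to transposing tensor layouts: PyTorch typically stores conv kernels as (out, in, H, W) while Flax expects (H, W, in, out), and PyTorch linear weights as (out, in) while Flax expects (in, out). A hedged sketch of those per-tensor transposes (helper names are illustrative, not from this repository):

```python
import numpy as np

def torch_conv_to_flax(kernel: np.ndarray) -> np.ndarray:
    """Transpose a PyTorch conv kernel (O, I, H, W) to Flax layout (H, W, I, O)."""
    return np.transpose(kernel, (2, 3, 1, 0))

def torch_linear_to_flax(weight: np.ndarray) -> np.ndarray:
    """Transpose a PyTorch linear weight (out, in) to Flax layout (in, out)."""
    return weight.T
```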
Acknowledgments
- Apple for the original FastViT implementation.
- Gemini 3 Pro and Claude 4.5 Sonnet for assistance with code generation, debugging, and JAX/Flax translation patterns.