FastViT-JAX (Weights Port)

This repository provides weight porting and inference utilities for using Apple FastViT models in JAX/Flax.

This is not a reimplementation of FastViT; the code is inference-only and does not cover training.


Quick Start

If you just want the weights, you can load a specific model variant (e.g., fastvit_sa12) directly from Python without cloning the repository.

  1. Install dependencies:

    pip install huggingface_hub flax orbax-checkpoint
    
  2. Load the Model:

    import sys
    from huggingface_hub import snapshot_download
    
    # Download code and specific model weights (skips other checkpoints)
    repo_path = snapshot_download(
        repo_id="SilverGrace-26/fastvit-jax-weights",
        allow_patterns=[
            "flax_models/**",                # Downloads the model architecture code
            "weights/orbax/fastvit_sa12/**"  # Downloads ONLY the sa12 weights
        ]
    )
    
    # Add the downloaded folder to python path so we can import the model
    sys.path.append(repo_path)
    
    # Now import the code as if it were local
    from flax_models.fastvit import FastViT
    

What it provides

  • A JAX/Flax forward-pass implementation compatible with Apple’s FastViT
  • Utilities to:
    • Load official PyTorch FastViT checkpoints
    • Apply reparameterization / fusion
    • Convert fused weights to JAX/Flax format
    • Verify numerical parity between PyTorch and JAX inference
  • A self-contained setup for testing and validation
  • JAX/Flax weights tested against PyTorch outputs, all within a max-difference threshold of 1e-4
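
The 1e-4 criterion amounts to an elementwise comparison of the two frameworks' outputs. A minimal numpy sketch (the logit values here are illustrative, not real model outputs):

```python
import numpy as np

# Stand-ins for logits produced by the PyTorch and JAX models on the same input.
logits_pt = np.array([0.12345, -1.00001, 3.0], dtype=np.float32)
logits_jax = np.array([0.12349, -1.00005, 3.0], dtype=np.float32)

# Max absolute elementwise difference is the parity metric.
max_diff = float(np.max(np.abs(logits_pt - logits_jax)))
assert max_diff < 1e-4, f"parity failed: max diff {max_diff}"
print(f"max abs diff: {max_diff:.2e}")
```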

Note

The original model was trained on ImageNet-1K, but ImageNette is used here for validation because of its small size. This does not affect parity testing; users may substitute a dataset of their choice.


Usage

If you want to test the provided weights with the provided code (you will need Git LFS to clone the repository):

  1. Extract the dataset with tar -xzf imagenette-320-val

  2. Run one of the following commands:

    • uv run python inference_test_random.py --model {model_name} --load-orbax
    • uv run python inference_test_real.py --model {model_name} --dataset-dir /path/to/your/validation_set/if_not_default --load-orbax

Vendored Apple files

To make this repository self-contained and easy to test, three files under pytorch_models/ are copied verbatim from Apple’s official FastViT repository, with minimal modifications (import paths only).

These files are required for:

  • Loading official FastViT checkpoints
  • Applying the original reparameterization (fusion) logic

Each vendored file:

  • Retains Apple’s original copyright notice
  • Is covered by Apple’s FastViT license (see LICENSE)

All other files in this repository are original.


Weight porting workflow

At a high level, the process is:

  1. Load an official PyTorch FastViT checkpoint
  2. Instantiate the training-time architecture
  3. Apply reparameterization / fusion to obtain inference-time weights
  4. Export the fused weights
  5. Convert the fused weights into JAX/Flax format
  6. Run inference in both frameworks and compare outputs

This ensures that the JAX model is functionally equivalent to the original PyTorch model.
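
The conversion in step 5 is mostly a matter of layout: PyTorch stores convolution kernels as (O, I, H, W) while Flax expects (H, W, I, O). A hedged sketch of that transpose, assuming the fused weights have already been exported to numpy (the shapes and function name are illustrative):

```python
import numpy as np

def pt_conv_to_flax(w_oihw: np.ndarray) -> np.ndarray:
    """Transpose a PyTorch conv kernel (O, I, H, W) to Flax layout (H, W, I, O)."""
    return np.transpose(w_oihw, (2, 3, 1, 0))

# Example: a fused 3x3 conv with 48 input and 64 output channels.
w_pt = np.zeros((64, 48, 3, 3), dtype=np.float32)
w_flax = pt_conv_to_flax(w_pt)
print(w_flax.shape)  # (3, 3, 48, 64)
```

Linear layers need a similar (but simpler) transpose, since PyTorch stores them as (out, in) and Flax as (in, out).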


Acknowledgments

  • Apple for the original FastViT implementation.
  • Gemini 3 Pro and Claude 4.5 Sonnet for assistance with code generation, debugging, and JAX/Flax translation patterns.