# fastvit-jax-weights / inference_test_random.py
# Copyright (c) 2025 FastViT-JAX Contributors
#
# Licensed under the MIT License. See the LICENSE file for details.
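
"""Sanity-check that the PyTorch and Flax FastViT implementations agree.

Feeds one fixed random NHWC tensor through both models and reports the
maximum absolute difference between their logits.

Usage (flags as defined by the argparse setup below):

    python inference_test_random.py --model fastvit_t8
    python inference_test_random.py --model fastvit_sa24 --load-orbax
"""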
import os
import sys
import argparse
import torch
import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp

try:
    from pytorch_models.fastvit import (
        fastvit_t8 as pt_fastvit_t8,
        fastvit_t12 as pt_fastvit_t12,
        fastvit_s12 as pt_fastvit_s12,
        fastvit_sa12 as pt_fastvit_sa12,
        fastvit_sa24 as pt_fastvit_sa24,
        fastvit_sa36 as pt_fastvit_sa36,
        fastvit_ma36 as pt_fastvit_ma36,
    )
except ImportError:
    print("!!! Could not import PyTorch model. Check 'pytorch_models/fastvit.py'")
    sys.exit(1)

try:
    from flax_models.fastvit import (
        fastvit_t8 as flax_fastvit_t8,
        fastvit_t12 as flax_fastvit_t12,
        fastvit_s12 as flax_fastvit_s12,
        fastvit_sa12 as flax_fastvit_sa12,
        fastvit_sa24 as flax_fastvit_sa24,
        fastvit_sa36 as flax_fastvit_sa36,
        fastvit_ma36 as flax_fastvit_ma36,
    )
except ImportError:
    print("!!! Could not import Flax model. Check 'flax_models/fastvit.py'")
    sys.exit(1)

try:
    from weight_conversion import convert_pytorch_to_flax
except ImportError:
    print("!!! Could not import conversion script. Check 'weight_conversion.py'")
    sys.exit(1)
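
# Per-variant registry: model constructors for both frameworks plus the
# architecture details. stage_config and token_mixers are fed to the weight
# converter below; embed_dims is kept for reference only.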
MODEL_REGISTRY = {
    'fastvit_t8': {
        'pt_model': pt_fastvit_t8,
        'flax_model': flax_fastvit_t8,
        'stage_config': [2, 2, 4, 2],
        'embed_dims': [48, 96, 192, 384],
        'token_mixers': ['repmixer'] * 4,
    },
    'fastvit_t12': {
        'pt_model': pt_fastvit_t12,
        'flax_model': flax_fastvit_t12,
        'stage_config': [2, 2, 6, 2],
        'embed_dims': [64, 128, 256, 512],
        'token_mixers': ['repmixer'] * 4,
    },
    'fastvit_s12': {
        'pt_model': pt_fastvit_s12,
        'flax_model': flax_fastvit_s12,
        'stage_config': [2, 2, 6, 2],
        'embed_dims': [64, 128, 256, 512],
        'token_mixers': ['repmixer'] * 4,
    },
    'fastvit_sa12': {
        'pt_model': pt_fastvit_sa12,
        'flax_model': flax_fastvit_sa12,
        'stage_config': [2, 2, 6, 2],
        'embed_dims': [64, 128, 256, 512],
        'token_mixers': ['repmixer', 'repmixer', 'repmixer', 'attention'],
    },
    'fastvit_sa24': {
        'pt_model': pt_fastvit_sa24,
        'flax_model': flax_fastvit_sa24,
        'stage_config': [4, 4, 12, 4],
        'embed_dims': [64, 128, 256, 512],
        'token_mixers': ['repmixer', 'repmixer', 'repmixer', 'attention'],
    },
    'fastvit_sa36': {
        'pt_model': pt_fastvit_sa36,
        'flax_model': flax_fastvit_sa36,
        'stage_config': [6, 6, 18, 6],
        'embed_dims': [64, 128, 256, 512],
        'token_mixers': ['repmixer', 'repmixer', 'repmixer', 'attention'],
    },
    'fastvit_ma36': {
        'pt_model': pt_fastvit_ma36,
        'flax_model': flax_fastvit_ma36,
        'stage_config': [6, 6, 18, 6],
        'embed_dims': [76, 152, 304, 608],
        'token_mixers': ['repmixer', 'repmixer', 'repmixer', 'attention'],
    },
}


def print_stat(name, tensor):
    """Print mean/std/min/max for a PyTorch tensor, JAX array, or ndarray."""
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.detach().cpu().numpy()
    elif hasattr(tensor, '__array__'):
        tensor = np.array(tensor)
    print(f"{name: <10} | Mean: {tensor.mean():.6f} | Std: {tensor.std():.6f} | "
          f"Min: {tensor.min():.6f} | Max: {tensor.max():.6f}")


def main():
    parser = argparse.ArgumentParser(description='Compare PyTorch and Flax models')
    parser.add_argument('--model', type=str, default='fastvit_t8', choices=list(MODEL_REGISTRY.keys()))
    parser.add_argument('--threshold', type=float, default=1e-4)
    parser.add_argument('--load-orbax', action='store_true', help='Load from Orbax checkpoint')
    args = parser.parse_args()

    print("=" * 80)
    print(f"FastViT: PyTorch vs. Flax Verification - {args.model.upper()}")
    print(f"Mode: {'Orbax Loading' if args.load_orbax else 'On-the-fly Conversion'}")
    print("=" * 80)

    model_info = MODEL_REGISTRY[args.model]
    pt_model_fn = model_info['pt_model']
    flax_model_fn = model_info['flax_model']
    stage_config = model_info['stage_config']
    token_mixers = model_info['token_mixers']
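
    # Seed NumPy so both frameworks see the identical random input. The Flax
    # model consumes NHWC directly; PyTorch expects NCHW, hence the permute.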
    np.random.seed(42)
    input_shape = (1, 256, 256, 3)
    x_np = np.random.normal(0, 1, input_shape).astype(np.float32)
    x_pt = torch.from_numpy(x_np).permute(0, 3, 1, 2)
    x_flax = jnp.array(x_np)
print("Loading PyTorch Model (Inference Mode)...")
try:
model_pt = pt_model_fn(inference_mode=True)
except TypeError:
model_pt = pt_model_fn()
model_pt.eval()
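
    # The checkpoints under weights/fused/ are assumed to hold reparameterized
    # (fused) weights, matching the inference-mode model built above.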
    pt_weights_path = f"weights/fused/{args.model}_fused.pth"
    if not os.path.exists(pt_weights_path):
        print(f"!!! Fused weights missing: {pt_weights_path}")
        return
    checkpoint = torch.load(pt_weights_path, map_location='cpu')
    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
    model_pt.load_state_dict(state_dict)
    with torch.no_grad():
        out_pt = model_pt(x_pt)
    print_stat("PyTorch", out_pt)
    print("-" * 80)
print("Initializing Flax Model...")
model_fx = flax_model_fn(num_classes=1000)
init_variables = model_fx.init(jax.random.PRNGKey(0), x_flax, train=False)
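
    # The freshly initialized variables serve only as a structural template:
    # they become the abstract restore target in the Orbax branch and are
    # simply replaced by the converted weights otherwise.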
    if args.load_orbax:
        print("Loading Orbax Checkpoint...")
        orbax_dir = f"weights/orbax/{args.model}"
        abs_path = os.path.abspath(orbax_dir)
        if not os.path.exists(abs_path):
            print(f"!!! Orbax path missing: {abs_path}")
            return
        abstract_tree = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, init_variables)
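        # StandardRestore takes an abstract pytree (shapes/dtypes only) that
        # tells Orbax which structure to materialize from the checkpoint.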
        with ocp.CheckpointManager(abs_path) as mngr:
            if mngr.latest_step() is None:
                print("!!! No checkpoints found.")
                return
            restore_args = ocp.args.StandardRestore(abstract_tree)
            flax_vars = mngr.restore(mngr.latest_step(), args=restore_args)
        print("✓ Orbax weights loaded")
    else:
        print("Converting Weights from PyTorch...")
        # Reuse the fused state_dict loaded above and map it to the Flax parameter tree.
        flax_vars = convert_pytorch_to_flax(state_dict, stage_config, token_mixers, verbose=False)

    out_fx = model_fx.apply(flax_vars, x_flax, train=False)
    res_pt = out_pt.detach().cpu().numpy()
    res_fx = np.array(out_fx)
    print_stat("Flax", res_fx)

    diff = np.abs(res_pt - res_fx)
    max_diff = diff.max()
    print("-" * 80)
    print(f"RESULTS: Max Diff = {max_diff:.8f}")
    if max_diff < args.threshold:
        print(":) SUCCESS")
    else:
        print("!!! FAILURE")
    print("=" * 80)


if __name__ == "__main__":
    main()