# FastViT: PyTorch vs. Flax output verification script.
| import os |
| import argparse |
| import torch |
| import jax |
| import jax.numpy as jnp |
| import numpy as np |
| import orbax.checkpoint as ocp |
|
|
try:
    from pytorch_models.fastvit import (
        fastvit_t8 as pt_fastvit_t8,
        fastvit_t12 as pt_fastvit_t12,
        fastvit_s12 as pt_fastvit_s12,
        fastvit_sa12 as pt_fastvit_sa12,
        fastvit_sa24 as pt_fastvit_sa24,
        fastvit_sa36 as pt_fastvit_sa36,
        fastvit_ma36 as pt_fastvit_ma36,
    )
except ImportError:
    # Fail fast: the comparison is meaningless without the reference models.
    print("!!! Could not import PyTorch model. Check 'pytorch_models/fastvit.py'")
    # `exit()` is a site-module convenience that may be absent (e.g. under
    # `python -S`); raising SystemExit is the reliable equivalent.
    raise SystemExit(1)
|
|
try:
    from flax_models.fastvit import (
        fastvit_t8 as flax_fastvit_t8,
        fastvit_t12 as flax_fastvit_t12,
        fastvit_s12 as flax_fastvit_s12,
        fastvit_sa12 as flax_fastvit_sa12,
        fastvit_sa24 as flax_fastvit_sa24,
        fastvit_sa36 as flax_fastvit_sa36,
        fastvit_ma36 as flax_fastvit_ma36,
    )
except ImportError:
    # Fail fast: nothing to compare against without the Flax implementations.
    print("!!! Could not import Flax model. Check 'flax_models/fastvit.py'")
    # `exit()` is a site-module convenience that may be absent (e.g. under
    # `python -S`); raising SystemExit is the reliable equivalent.
    raise SystemExit(1)
|
|
try:
    from weight_conversion import convert_pytorch_to_flax
except ImportError:
    # Fail fast: the on-the-fly conversion path depends on this helper.
    print("!!! Could not import conversion script. Check 'weight_conversion.py'")
    # `exit()` is a site-module convenience that may be absent (e.g. under
    # `python -S`); raising SystemExit is the reliable equivalent.
    raise SystemExit(1)
|
|
|
|
# Token-mixer layouts shared across variants (one entry per stage).
_REPMIXER_ONLY = ['repmixer'] * 4
_REPMIXER_ATTN = ['repmixer', 'repmixer', 'repmixer', 'attention']


def _variant(pt_model, flax_model, stage_config, embed_dims, token_mixers):
    """Describe one FastViT variant: its constructors plus architecture config."""
    return {
        'pt_model': pt_model,
        'flax_model': flax_model,
        'stage_config': stage_config,
        'embed_dims': embed_dims,
        # Copy so every entry owns an independent list object.
        'token_mixers': list(token_mixers),
    }


# Maps CLI model names to their PyTorch/Flax constructors together with the
# per-stage depths, embedding widths, and token-mixer layout that the weight
# converter needs.
MODEL_REGISTRY = {
    'fastvit_t8': _variant(pt_fastvit_t8, flax_fastvit_t8,
                           [2, 2, 4, 2], [48, 96, 192, 384], _REPMIXER_ONLY),
    'fastvit_t12': _variant(pt_fastvit_t12, flax_fastvit_t12,
                            [2, 2, 6, 2], [64, 128, 256, 512], _REPMIXER_ONLY),
    'fastvit_s12': _variant(pt_fastvit_s12, flax_fastvit_s12,
                            [2, 2, 6, 2], [64, 128, 256, 512], _REPMIXER_ONLY),
    'fastvit_sa12': _variant(pt_fastvit_sa12, flax_fastvit_sa12,
                             [2, 2, 6, 2], [64, 128, 256, 512], _REPMIXER_ATTN),
    'fastvit_sa24': _variant(pt_fastvit_sa24, flax_fastvit_sa24,
                             [4, 4, 12, 4], [64, 128, 256, 512], _REPMIXER_ATTN),
    'fastvit_sa36': _variant(pt_fastvit_sa36, flax_fastvit_sa36,
                             [6, 6, 18, 6], [64, 128, 256, 512], _REPMIXER_ATTN),
    'fastvit_ma36': _variant(pt_fastvit_ma36, flax_fastvit_ma36,
                             [6, 6, 18, 6], [76, 152, 304, 608], _REPMIXER_ATTN),
}
|
|
def print_stat(name, tensor):
    """Print a one-line mean/std/min/max summary of a tensor-like value.

    Accepts a torch tensor, anything exposing ``__array__`` (e.g. a JAX or
    NumPy array), or any object with mean/std/min/max methods.
    """
    if isinstance(tensor, torch.Tensor):
        values = tensor.detach().cpu().numpy()
    elif hasattr(tensor, '__array__'):
        values = np.array(tensor)
    else:
        values = tensor

    summary = (
        f"{name: <10} | Mean: {values.mean():.6f} | Std: {values.std():.6f} | "
        f"Min: {values.min():.6f} | Max: {values.max():.6f}"
    )
    print(summary)
|
|
def _restore_orbax(model_name, init_variables):
    """Restore the latest Orbax checkpoint for `model_name`.

    Args:
        model_name: registry key, used to locate `weights/orbax/<name>`.
        init_variables: freshly initialized Flax variables; their pytree
            structure provides the abstract shapes/dtypes Orbax needs.

    Returns:
        The restored variable pytree, or None when the checkpoint directory
        or a saved step is missing (an error message is printed).
    """
    abs_path = os.path.abspath(f"weights/orbax/{model_name}")
    if not os.path.exists(abs_path):
        print(f"!!! Orbax path missing: {abs_path}")
        return None

    # StandardRestore needs an abstract pytree mirroring the on-disk
    # structure; derive it from the fresh init.
    abstract_tree = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, init_variables)

    with ocp.CheckpointManager(abs_path) as mngr:
        if mngr.latest_step() is None:
            print("!!! No checkpoints found.")
            return None
        restore_args = ocp.args.StandardRestore(abstract_tree)
        flax_vars = mngr.restore(mngr.latest_step(), args=restore_args)
    print("✓ Orbax weights loaded")
    return flax_vars


def main():
    """Run both models on one shared random input and compare the outputs.

    Loads the fused PyTorch weights, obtains matching Flax variables (either
    restored from an Orbax checkpoint or converted on the fly), then reports
    the max absolute difference against `--threshold`.
    """
    parser = argparse.ArgumentParser(description='Compare PyTorch and Flax models')
    parser.add_argument('--model', type=str, default='fastvit_t8', choices=list(MODEL_REGISTRY.keys()))
    parser.add_argument('--threshold', type=float, default=1e-4)
    parser.add_argument('--load-orbax', action='store_true', help='Load from Orbax checkpoint')
    args = parser.parse_args()

    print("=" * 80)
    print(f"FastViT: PyTorch vs. Flax Verification - {args.model.upper()}")
    print(f"Mode: {'Orbax Loading' if args.load_orbax else 'On-the-fly Conversion'}")
    print("=" * 80)

    model_info = MODEL_REGISTRY[args.model]

    # One deterministic random sample, fed to both frameworks:
    # NHWC for Flax, permuted to NCHW for torch.
    np.random.seed(42)
    x_np = np.random.normal(0, 1, (1, 256, 256, 3)).astype(np.float32)
    x_pt = torch.from_numpy(x_np).permute(0, 3, 1, 2)
    x_flax = jnp.array(x_np)

    print("Loading PyTorch Model (Inference Mode)...")
    try:
        # Prefer the fused/reparameterized graph when the constructor
        # supports an inference_mode flag; fall back otherwise.
        model_pt = model_info['pt_model'](inference_mode=True)
    except TypeError:
        model_pt = model_info['pt_model']()
    model_pt.eval()

    pt_weights_path = f"weights/fused/{args.model}_fused.pth"
    if not os.path.exists(pt_weights_path):
        print(f"!!! Fused weights missing: {pt_weights_path}")
        return

    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # trusted checkpoints (weights_only=True would be safer if the file is
    # known to be a plain state dict).
    checkpoint = torch.load(pt_weights_path, map_location='cpu')
    # Checkpoints may wrap the weights under a 'state_dict' key.
    state_dict = checkpoint.get('state_dict', checkpoint)

    model_pt.load_state_dict(state_dict)
    with torch.no_grad():
        out_pt = model_pt(x_pt)
    print_stat("PyTorch", out_pt)

    print("-" * 80)
    print("Initializing Flax Model...")
    model_fx = model_info['flax_model'](num_classes=1000)
    init_variables = model_fx.init(jax.random.PRNGKey(0), x_flax, train=False)

    if args.load_orbax:
        print("Loading Orbax Checkpoint...")
        flax_vars = _restore_orbax(args.model, init_variables)
        if flax_vars is None:
            return
    else:
        print("Converting Weights from PyTorch...")
        flax_vars = convert_pytorch_to_flax(
            state_dict, model_info['stage_config'], model_info['token_mixers'], verbose=False)

    out_fx = model_fx.apply(flax_vars, x_flax, train=False)
    res_pt = out_pt.detach().cpu().numpy()
    res_fx = np.array(out_fx)
    print_stat("Flax", res_fx)

    max_diff = np.abs(res_pt - res_fx).max()
    print("-" * 80)
    print(f"RESULTS: Max Diff = {max_diff:.8f}")
    if max_diff < args.threshold:
        print(":) SUCCESS")
    else:
        print("!!! FAILURE")
    print("=" * 80)
|
|
# Script entry point: run the CLI-driven comparison.
if __name__ == "__main__":
    main()
|
|