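# Repository page header that appears to have been captured with the file
# (not part of the program):
#   fastvit-jax-weights / save_weights.py
#   SilverGrace-26's picture
#   Finishing
#   d18100d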
# Copyright (c) 2025 FastViT-JAX Contributors
#
# Utility script for converting FastViT weights to Orbax format.
# This format is TPU-friendly and compatible with modern JAX distributed workloads.
#
# Licensed under the MIT License. See the LICENSE file for details.
import os
import argparse
import shutil
import jax
import torch
import orbax.checkpoint as ocp
import numpy as np
# Import from existing modules
from weight_conversion import convert_pytorch_to_flax, MODEL_REGISTRY
def save_orbax_checkpoint(params, output_dir, step=0):
    """Write ``params`` as an Orbax checkpoint under ``output_dir``.

    Args:
        params: Pytree of parameters, passed to ``ocp.args.StandardSave``.
        output_dir: Checkpoint directory; converted to an absolute path
            before it is handed to the checkpoint manager.
        step: Step number the checkpoint is saved under (default 0).
    """
    checkpoint_root = os.path.abspath(output_dir)
    manager_options = ocp.CheckpointManagerOptions(max_to_keep=1, create=True)

    with ocp.CheckpointManager(checkpoint_root, options=manager_options) as manager:
        print(f" > Saving to: {checkpoint_root}")
        manager.save(step, args=ocp.args.StandardSave(params))
        # Wait for the save to complete before the manager is closed.
        manager.wait_until_finished()
def verify_checkpoint(output_dir, original_params, step=0):
    """Restore a saved checkpoint and compare it with the in-memory parameters.

    Every leaf of the restored tree is compared against the corresponding
    leaf of ``original_params`` (same flattening order), not just the first
    one, so corruption in any array is detected.

    Args:
        output_dir: Directory the checkpoint was saved to.
        original_params: Pytree that was saved; also used to build the
            shape/dtype template for restoring.
        step: Step number to restore (default 0).

    Returns:
        True if the leaf counts match, every leaf pair has the same shape,
        and the maximum absolute difference of every pair is below 1e-6.
        Otherwise prints the reason and returns False.
    """
    abs_path = os.path.abspath(output_dir)
    abstract_tree = jax.tree_util.tree_map(
        ocp.utils.to_shape_dtype_struct, original_params
    )
    with ocp.CheckpointManager(abs_path) as mngr:
        restore_args = ocp.args.StandardRestore(abstract_tree)
        restored = mngr.restore(step, args=restore_args)

    leaves_orig = jax.tree_util.tree_leaves(original_params)
    leaves_rest = jax.tree_util.tree_leaves(restored)
    if len(leaves_orig) != len(leaves_rest):
        print(" !!! Verification FAILED: Structure mismatch")
        return False

    # An empty tree has nothing to compare; the loop is simply skipped.
    for index, (expected, actual) in enumerate(zip(leaves_orig, leaves_rest)):
        expected = np.asarray(expected)
        actual = np.asarray(actual)
        if expected.shape != actual.shape:
            print(
                f" !!! Verification FAILED: Shape mismatch at leaf {index} "
                f"({expected.shape} vs {actual.shape})"
            )
            return False
        if expected.size == 0:
            # .max() raises on zero-size arrays; nothing to compare anyway.
            continue
        diff = float(np.abs(expected - actual).max())
        # Written as "not (diff < tol)" so that a NaN difference also fails.
        if not (diff < 1e-6):
            print(f" !!! Verification FAILED: Data mismatch (diff: {diff:.8f})")
            return False
    return True
def process_model(model_name, clean=False):
    """Convert one model's fused PyTorch weights into an Orbax checkpoint.

    Reads ``weights/fused/<model_name>_fused.pth``, converts it with
    ``convert_pytorch_to_flax`` using the model's ``MODEL_REGISTRY`` entry,
    saves the result under ``weights/orbax/<model_name>`` and verifies it.

    Args:
        model_name: Key into ``MODEL_REGISTRY``; also used to build the
            input and output paths.
        clean: When True, an existing output directory is removed first.

    Returns:
        True if the checkpoint was saved and verified. False if the input
        file does not exist, verification fails, or any exception is raised
        during loading, conversion or saving (the error is printed).
    """
    input_path = f"weights/fused/{model_name}_fused.pth"
    output_dir = f"weights/orbax/{model_name}"

    # Guard clause: no fused weights for this model, nothing is printed.
    if not os.path.exists(input_path):
        return False

    print(f"\n[{model_name.upper()}] Processing...")
    print(f" Input: {input_path}")
    print(f" Output: {output_dir}")

    if clean and os.path.exists(output_dir):
        print(f" ! Cleaning existing directory: {output_dir}")
        shutil.rmtree(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    try:
        checkpoint = torch.load(input_path, map_location="cpu")
        # Some checkpoints wrap the weights under a "state_dict" key.
        weights = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint

        # Registry entries come either as a 3-tuple or as a mapping.
        entry = MODEL_REGISTRY[model_name]
        if isinstance(entry, tuple):
            _unused, stage_config, token_mixers = entry
        else:
            stage_config = entry['stage_config']
            token_mixers = entry['token_mixers']

        converted = convert_pytorch_to_flax(weights, stage_config, token_mixers)
        save_orbax_checkpoint(converted, output_dir)

        verified = verify_checkpoint(output_dir, converted)
        if not verified:
            return False
        print(f" ✓ {model_name} successfully converted to Orbax.")
        return True
    except Exception as e:
        print(f" !!! Error processing {model_name}: {e}")
        return False
def main():
    """Batch-convert every registered model whose fused weights are present.

    Parses the ``--clean`` flag, iterates over ``MODEL_REGISTRY`` and calls
    ``process_model`` for each model that has a fused weights file. The
    summary distinguishes three outcomes: converted, failed (file present
    but conversion or verification did not succeed) and skipped (no fused
    weights file), so real failures are no longer reported as missing files.
    """
    parser = argparse.ArgumentParser(description='Batch convert fused FastViT weights to Orbax (TPU-friendly)')
    parser.add_argument('--clean', action='store_true', help='Clean output directories before saving')
    args = parser.parse_args()

    print(f"{'='*80}")
    print("FastViT Orbax Batch Conversion")
    print("Scanning 'weights/fused/' for models in registry...")
    print(f"{'='*80}")
    os.makedirs("weights/orbax", exist_ok=True)

    converted_count = 0
    failed_count = 0
    skipped_count = 0
    for model_name in MODEL_REGISTRY:
        # process_model() returns False both for a missing input file and for
        # a failed conversion, so check for the file here to tell them apart.
        # This path must mirror the input path process_model() reads.
        if not os.path.exists(f"weights/fused/{model_name}_fused.pth"):
            skipped_count += 1
            continue
        if process_model(model_name, clean=args.clean):
            converted_count += 1
        else:
            failed_count += 1

    print(f"\n{'-'*80}")
    print("Summary:")
    print(f" Converted: {converted_count}")
    print(f" Failed: {failed_count} (conversion or verification error)")
    print(f" Skipped: {skipped_count} (fused weights not found)")
    print(f"{'='*80}\n")
# Run the batch conversion only when this file is executed as a script.
if __name__ == "__main__":
    main()