| |
| |
| |
| |
| |
| |
|
|
| import os |
| import argparse |
| import shutil |
| import jax |
| import torch |
| import orbax.checkpoint as ocp |
| import numpy as np |
|
|
| |
| from weight_conversion import convert_pytorch_to_flax, MODEL_REGISTRY |
|
|
def save_orbax_checkpoint(params, output_dir, step=0):
    """Persist a JAX pytree of weights as an Orbax checkpoint.

    Args:
        params: Pytree of arrays to serialize.
        output_dir: Destination directory (resolved to an absolute path;
            created by the manager thanks to ``create=True``).
        step: Checkpoint step number to save under (default 0).
    """
    target = os.path.abspath(output_dir)
    opts = ocp.CheckpointManagerOptions(max_to_keep=1, create=True)

    with ocp.CheckpointManager(target, options=opts) as manager:
        print(f" > Saving to: {target}")
        manager.save(step, args=ocp.args.StandardSave(params))
        # Orbax saves asynchronously; block here so callers (e.g. the
        # verification step) see a fully-written checkpoint.
        manager.wait_until_finished()
|
|
def verify_checkpoint(output_dir, original_params, step=0):
    """Restore the checkpoint at *output_dir* and compare it to *original_params*.

    Args:
        output_dir: Directory previously passed to ``save_orbax_checkpoint``.
        original_params: The pytree that was saved; used both as the abstract
            restore target and as the reference data.
        step: Checkpoint step to restore (default 0).

    Returns:
        True when the restored tree has the same structure and every leaf
        matches the original within 1e-6 max absolute difference, else False.
    """
    abs_path = os.path.abspath(output_dir)

    # Orbax needs a shape/dtype-only ("abstract") tree to know what to restore into.
    abstract_tree = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, original_params)

    with ocp.CheckpointManager(abs_path) as mngr:
        restore_args = ocp.args.StandardRestore(abstract_tree)
        restored = mngr.restore(step, args=restore_args)

    leaves_orig, treedef_orig = jax.tree_util.tree_flatten(original_params)
    leaves_rest, treedef_rest = jax.tree_util.tree_flatten(restored)

    # Compare the full tree structure, not just the leaf count: equal counts
    # can still hide renamed/reordered keys.
    if treedef_orig != treedef_rest:
        print(" !!! Verification FAILED: Structure mismatch")
        return False

    # BUG FIX: the previous version only compared the FIRST leaf, so any
    # corruption in the remaining parameters went undetected. Check every leaf
    # and report the worst deviation. `default=0.0` keeps empty trees passing.
    diff = max(
        (float(np.abs(np.asarray(o) - np.asarray(r)).max())
         for o, r in zip(leaves_orig, leaves_rest)),
        default=0.0,
    )

    if diff < 1e-6:
        return True
    else:
        print(f" !!! Verification FAILED: Data mismatch (diff: {diff:.8f})")
        return False
|
|
def process_model(model_name, clean=False):
    """Convert one fused PyTorch checkpoint to a verified Orbax checkpoint.

    Looks for ``weights/fused/<model_name>_fused.pth`` and writes the converted
    checkpoint to ``weights/orbax/<model_name>``.

    Args:
        model_name: Key into MODEL_REGISTRY; also determines the paths above.
        clean: When True, delete any existing output directory before saving.

    Returns:
        True on successful conversion + verification; False when the input
        file is missing (caller counts this as "skipped") or any step fails.
    """
    input_path = f"weights/fused/{model_name}_fused.pth"
    output_dir = f"weights/orbax/{model_name}"

    if not os.path.exists(input_path):
        return False

    print(f"\n[{model_name.upper()}] Processing...")
    print(f" Input: {input_path}")
    print(f" Output: {output_dir}")

    if clean and os.path.exists(output_dir):
        print(f" ! Cleaning existing directory: {output_dir}")
        shutil.rmtree(output_dir)

    os.makedirs(output_dir, exist_ok=True)

    try:
        # NOTE(review): torch.load unpickles arbitrary objects — fine for our
        # own fused files, but never point this at untrusted checkpoints.
        state_dict = torch.load(input_path, map_location="cpu")
        # Some checkpoints nest the weights under a "state_dict" key.
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]

        # Registry entries come in two shapes: a (name, stage_config,
        # token_mixers) tuple or a dict with explicit keys.
        reg_entry = MODEL_REGISTRY[model_name]
        if isinstance(reg_entry, tuple):
            _, stage_config, token_mixers = reg_entry
        else:
            stage_config = reg_entry['stage_config']
            token_mixers = reg_entry['token_mixers']

        flax_vars = convert_pytorch_to_flax(state_dict, stage_config, token_mixers)

        save_orbax_checkpoint(flax_vars, output_dir)

        if verify_checkpoint(output_dir, flax_vars):
            print(f" ✓ {model_name} successfully converted to Orbax.")
            return True
        else:
            return False

    except Exception as e:
        # FIX: the one-line message previously hid the failure site entirely;
        # print the full traceback so conversion bugs are diagnosable, then
        # keep the original summary line and best-effort False return.
        import traceback
        traceback.print_exc()
        print(f" !!! Error processing {model_name}: {e}")
        return False
|
|
def main():
    """CLI entry point: convert every registry model found under weights/fused/."""
    parser = argparse.ArgumentParser(
        description='Batch convert fused FastViT weights to Orbax (TPU-friendly)')
    parser.add_argument('--clean', action='store_true',
                        help='Clean output directories before saving')
    cli_args = parser.parse_args()

    banner = '=' * 80
    print(banner)
    print("FastViT Orbax Batch Conversion")
    print("Scanning 'weights/fused/' for models in registry...")
    print(banner)

    # Root output directory; per-model subdirs are created by process_model.
    os.makedirs("weights/orbax", exist_ok=True)

    converted = 0
    skipped = 0

    # A False return means either "input missing" or "conversion failed";
    # both land in the skipped bucket, matching the summary wording.
    for name in MODEL_REGISTRY:
        if process_model(name, clean=cli_args.clean):
            converted += 1
        else:
            skipped += 1

    print(f"\n{'-'*80}")
    print("Summary:")
    print(f" Converted: {converted}")
    print(f" Skipped: {skipped} (fused weights not found)")
    print(f"{banner}\n")


if __name__ == "__main__":
    main()
|
|