# Copyright (c) 2025 FastViT-JAX Contributors
#
# Licensed under the MIT License. See the LICENSE file for details.
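#
# Converts a fused FastViT PyTorch checkpoint (e.g. weights/fastvit_t8_fused.pth)
# into Flax-compatible 'params' and 'batch_stats' dictionaries and runs a
# sanity-check forward pass.
#
# Usage:
#   python weight_conversion.py --model fastvit_t8 --verbose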
import os
import argparse
import torch
import jax
import jax.numpy as jnp
import numpy as np
from flax_models.fastvit import (
fastvit_t8, fastvit_t12, fastvit_s12,
fastvit_sa12, fastvit_sa24, fastvit_sa36, fastvit_ma36
)
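# Maps model name -> (constructor, blocks per stage, token mixer type per stage).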
MODEL_REGISTRY = {
'fastvit_t8': (fastvit_t8, [2, 2, 4, 2], ['repmixer'] * 4),
'fastvit_t12': (fastvit_t12, [2, 2, 6, 2], ['repmixer'] * 4),
'fastvit_s12': (fastvit_s12, [2, 2, 6, 2], ['repmixer'] * 4),
'fastvit_sa12': (fastvit_sa12, [2, 2, 6, 2], ['repmixer', 'repmixer', 'repmixer', 'attention']),
'fastvit_sa24': (fastvit_sa24, [4, 4, 12, 4], ['repmixer', 'repmixer', 'repmixer', 'attention']),
'fastvit_sa36': (fastvit_sa36, [6, 6, 18, 6], ['repmixer', 'repmixer', 'repmixer', 'attention']),
'fastvit_ma36': (fastvit_ma36, [6, 6, 18, 6], ['repmixer', 'repmixer', 'repmixer', 'attention']),
}
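# PyTorch stores conv kernels as (out_ch, in_ch, H, W); Flax expects (H, W, in_ch, out_ch).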
def transpose_conv_weight(weight):
weight_np = weight.detach().cpu().numpy()
return np.transpose(weight_np, (2, 3, 1, 0))
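# PyTorch Linear weights are (out_features, in_features); Flax Dense kernels are (in_features, out_features).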
def transpose_dense_weight(weight):
weight_np = weight.detach().cpu().numpy()
return np.transpose(weight_np, (1, 0))
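# Layer-scale parameters are stored channel-first (e.g. (C, 1, 1)); move the channel axis last to match NHWC.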
def reshape_layer_scale(layer_scale):
ls_np = layer_scale.detach().cpu().numpy()
return np.transpose(ls_np, (1, 2, 0))
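# Walks the fused PyTorch state dict stage by stage and rebuilds the nested Flax
# parameter tree, tracking every consumed key so leftover keys can be reported.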
def convert_pytorch_to_flax(pytorch_state_dict, stage_config, token_mixers, verbose=False):
flax_params = {}
flax_batch_stats = {}
processed_keys = set()
def get_val(key):
if key not in pytorch_state_dict:
raise KeyError(f"Missing key in PyTorch dict: {key}")
processed_keys.add(key)
return pytorch_state_dict[key].detach().cpu().numpy()
def load_conv(pt_prefix):
w = pytorch_state_dict[f"{pt_prefix}.weight"]
processed_keys.add(f"{pt_prefix}.weight")
return {
'kernel': transpose_conv_weight(w),
'bias': get_val(f"{pt_prefix}.bias")
}
def load_dense(pt_prefix):
w = pytorch_state_dict[f"{pt_prefix}.weight"]
processed_keys.add(f"{pt_prefix}.weight")
return {
'kernel': transpose_dense_weight(w),
'bias': get_val(f"{pt_prefix}.bias")
}
def load_bn(pt_prefix):
return {
'params': {
'scale': get_val(f"{pt_prefix}.weight"),
'bias': get_val(f"{pt_prefix}.bias")
},
'stats': {
'mean': get_val(f"{pt_prefix}.running_mean"),
'var': get_val(f"{pt_prefix}.running_var")
}
}
def mark_bn_tracked(pt_prefix):
k = f"{pt_prefix}.num_batches_tracked"
if k in pytorch_state_dict: processed_keys.add(k)
if verbose: print("Converting PyTorch weights to Flax format...")
# 1. Stem
for i in range(3):
flax_params[f'stem_{i}'] = {'reparam_conv': load_conv(f"patch_embed.{i}.reparam_conv")}
# 2. Stages
network_idx = 0
for stage_idx, num_blocks in enumerate(stage_config):
mixer_type = token_mixers[stage_idx]
# Positional Embedding (Only for attention stages)
if mixer_type == 'attention':
pt_prefix = f"network.{network_idx}.reparam_conv"
if f"{pt_prefix}.weight" in pytorch_state_dict:
flax_params[f'pos_embed_{stage_idx}'] = {'reparam_conv': load_conv(pt_prefix)}
network_idx += 1
# Blocks
for block_idx in range(num_blocks):
pt_block = f"network.{network_idx}.{block_idx}"
flax_block = f"stages_{stage_idx}_{block_idx}"
block_p = {}
block_s = {}
# Token Mixer
if mixer_type == 'repmixer':
block_p['token_mixer'] = {'reparam_conv': load_conv(f"{pt_block}.token_mixer.reparam_conv")}
# Layer Scale 1 (RepMixer specific)
ls_key = f"{pt_block}.layer_scale"
block_p['layer_scale'] = reshape_layer_scale(pytorch_state_dict[ls_key])
processed_keys.add(ls_key)
elif mixer_type == 'attention':
# Norm
bn_data = load_bn(f"{pt_block}.norm")
block_p['norm'] = bn_data['params']
block_s['norm'] = bn_data['stats']
mark_bn_tracked(f"{pt_block}.norm")
# Attention (QKV + Proj)
w_qkv = pytorch_state_dict[f"{pt_block}.token_mixer.qkv.weight"]
processed_keys.add(f"{pt_block}.token_mixer.qkv.weight")
block_p['token_mixer'] = {
'qkv': {'kernel': transpose_dense_weight(w_qkv)},
'proj': load_dense(f"{pt_block}.token_mixer.proj")
}
# Layer Scale 1 (Attention specific)
ls_key = f"{pt_block}.layer_scale_1"
block_p['layer_scale_1'] = reshape_layer_scale(pytorch_state_dict[ls_key])
processed_keys.add(ls_key)
# ConvFFN
block_p['convffn'] = {
'conv_conv': {'kernel': transpose_conv_weight(pytorch_state_dict[f"{pt_block}.convffn.conv.conv.weight"])},
'fc1': load_conv(f"{pt_block}.convffn.fc1"),
'fc2': load_conv(f"{pt_block}.convffn.fc2")
}
processed_keys.add(f"{pt_block}.convffn.conv.conv.weight")
# ConvFFN BatchNorm
bn_data = load_bn(f"{pt_block}.convffn.conv.bn")
block_p['convffn']['conv_bn'] = bn_data['params']
block_s['convffn'] = {'conv_bn': bn_data['stats']}
mark_bn_tracked(f"{pt_block}.convffn.conv.bn")
# Layer Scale 2
if mixer_type == 'attention':
ls2_key = f"{pt_block}.layer_scale_2"
block_p['layer_scale_2'] = reshape_layer_scale(pytorch_state_dict[ls2_key])
processed_keys.add(ls2_key)
flax_params[flax_block] = block_p
if block_s: flax_batch_stats[flax_block] = block_s
network_idx += 1
# Downsampling (Proj)
if stage_idx < len(stage_config) - 1:
pt_ds = f"network.{network_idx}.proj"
flax_params[f"downsample_layers_{stage_idx}"] = {
'proj_0': {'lkb_reparam': load_conv(f"{pt_ds}.0.lkb_reparam")},
'proj_1': {'reparam_conv': load_conv(f"{pt_ds}.1.reparam_conv")}
}
network_idx += 1
# 3. Head
flax_params['conv_exp'] = {
'reparam_conv': load_conv("conv_exp.reparam_conv"),
'se': {
'reduce': load_conv("conv_exp.se.reduce"),
'expand': load_conv("conv_exp.se.expand")
}
}
flax_params['head'] = load_dense("head")
# 4. Verification
all_pt_keys = set(pytorch_state_dict.keys())
relevant_keys = {k for k in all_pt_keys if "num_batches_tracked" not in k}
missing_keys = relevant_keys - processed_keys
if missing_keys:
raise RuntimeError(f"PyTorch keys not converted: {sorted(missing_keys)}")
if verbose: print("✓ Conversion successful!")
return {'params': flax_params, 'batch_stats': flax_batch_stats}
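# Loads a fused (inference-mode) PyTorch checkpoint from disk and converts it to Flax variables.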
def load_fused_weights(pytorch_weights_path, stage_config, token_mixers, verbose=False):
if verbose: print(f"Loading weights from {pytorch_weights_path}")
if not os.path.exists(pytorch_weights_path):
raise FileNotFoundError(f"Weights not found: {pytorch_weights_path}")
pytorch_state_dict = torch.load(pytorch_weights_path, map_location='cpu')
return convert_pytorch_to_flax(pytorch_state_dict, stage_config, token_mixers, verbose=verbose)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='fastvit_t8', choices=list(MODEL_REGISTRY.keys()))
parser.add_argument('--verbose', action='store_true')
args = parser.parse_args()
model_fn, stage_config, token_mixers = MODEL_REGISTRY[args.model]
model = model_fn(num_classes=1000)
weights_path = f"weights/{args.model}_fused.pth"
params = load_fused_weights(weights_path, stage_config, token_mixers, verbose=args.verbose)
print("Testing forward pass...")
key = jax.random.PRNGKey(42)
x = jax.random.normal(key, (1, 256, 256, 3))
output = model.apply(params, x, train=False)
probs = jax.nn.softmax(output, axis=-1)
top5 = jnp.argsort(probs[0])[-5:][::-1]
print(f"Top prediction: Class {int(top5[0])} ({float(probs[0, top5[0]]):.4f})")
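# Persisting the converted variables is left to the caller; one possible approach
# (an illustrative sketch, not part of this script) is flax.serialization:
#   import flax
#   with open(f"weights/{args.model}_flax.msgpack", "wb") as f:
#       f.write(flax.serialization.to_bytes(params))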
if __name__ == "__main__":
main()