| |
| |
| |
|
|
| import os |
| import argparse |
| import torch |
| import jax |
| import jax.numpy as jnp |
| import numpy as np |
| from flax_models.fastvit import ( |
| fastvit_t8, fastvit_t12, fastvit_s12, |
| fastvit_sa12, fastvit_sa24, fastvit_sa36, fastvit_ma36 |
| ) |
|
|
# Maps CLI model name -> (Flax constructor, blocks-per-stage, token-mixer type per stage).
# The stage/mixer lists must match the layout of the PyTorch checkpoint being
# converted: stage depths drive how many `network.N.M` blocks are read, and the
# mixer type selects the repmixer vs. attention key schema per stage.
MODEL_REGISTRY = {
    'fastvit_t8': (fastvit_t8, [2, 2, 4, 2], ['repmixer'] * 4),
    'fastvit_t12': (fastvit_t12, [2, 2, 6, 2], ['repmixer'] * 4),
    'fastvit_s12': (fastvit_s12, [2, 2, 6, 2], ['repmixer'] * 4),
    'fastvit_sa12': (fastvit_sa12, [2, 2, 6, 2], ['repmixer', 'repmixer', 'repmixer', 'attention']),
    'fastvit_sa24': (fastvit_sa24, [4, 4, 12, 4], ['repmixer', 'repmixer', 'repmixer', 'attention']),
    'fastvit_sa36': (fastvit_sa36, [6, 6, 18, 6], ['repmixer', 'repmixer', 'repmixer', 'attention']),
    'fastvit_ma36': (fastvit_ma36, [6, 6, 18, 6], ['repmixer', 'repmixer', 'repmixer', 'attention']),
}
|
|
|
|
def transpose_conv_weight(weight):
    """Rearrange a conv kernel from PyTorch layout (O, I, H, W) to Flax layout (H, W, I, O).

    Args:
        weight: torch.Tensor conv weight; detached and moved to CPU before conversion.

    Returns:
        numpy.ndarray with axes permuted to (H, W, I, O).
    """
    kernel = weight.detach().cpu().numpy()
    return kernel.transpose((2, 3, 1, 0))
|
|
|
|
def transpose_dense_weight(weight):
    """Swap the axes of a linear-layer weight: PyTorch (out, in) -> Flax (in, out).

    Args:
        weight: torch.Tensor 2-D dense weight; detached and moved to CPU first.

    Returns:
        numpy.ndarray transposed to (in, out).
    """
    matrix = weight.detach().cpu().numpy()
    return matrix.transpose((1, 0))
|
|
|
|
def reshape_layer_scale(layer_scale):
    """Move the leading axis of a layer-scale tensor to the back.

    For the usual per-channel layer scale this maps PyTorch's (C, 1, 1) to
    Flax's channel-last (1, 1, C) — assumes a 3-D input; confirm against the
    model definition.

    Args:
        layer_scale: torch.Tensor; detached and moved to CPU before conversion.

    Returns:
        numpy.ndarray with axes permuted (1, 2, 0).
    """
    values = layer_scale.detach().cpu().numpy()
    return values.transpose((1, 2, 0))
|
|
|
|
def convert_pytorch_to_flax(pytorch_state_dict, stage_config, token_mixers, verbose=False):
    """Convert a fused FastViT PyTorch state dict into Flax params/batch_stats trees.

    Walks the checkpoint's ``patch_embed`` / ``network.N`` / ``conv_exp`` / ``head``
    key layout, transposing weights into Flax conventions along the way. Every key
    consumed is recorded in ``processed_keys``; at the end any leftover (unconverted)
    checkpoint key raises, so silent weight drops are impossible.

    Args:
        pytorch_state_dict: Mapping of checkpoint key -> torch.Tensor. Assumed to be
            the *fused/reparameterized* variant (keys like ``reparam_conv``,
            ``lkb_reparam``) — TODO confirm against the export script.
        stage_config: Number of blocks per stage (drives the ``network.N.M`` walk).
        token_mixers: Per-stage mixer type, 'repmixer' or 'attention'; selects which
            key schema each stage's blocks use.
        verbose: If True, print progress messages.

    Returns:
        Dict with 'params' and 'batch_stats' sub-trees for the Flax model.

    Raises:
        KeyError: A key required by the expected schema is absent.
        RuntimeError: The checkpoint contains keys this converter never consumed.
    """
    flax_params = {}
    flax_batch_stats = {}
    # Every checkpoint key we consume is recorded here for the completeness check.
    processed_keys = set()

    def get_val(key):
        # Fetch one tensor as a numpy array, recording the key as processed.
        if key not in pytorch_state_dict:
            raise KeyError(f"Missing key in PyTorch dict: {key}")
        processed_keys.add(key)
        return pytorch_state_dict[key].detach().cpu().numpy()

    def load_conv(pt_prefix):
        # Conv weight+bias under `pt_prefix`, kernel transposed to Flax HWIO layout.
        w = pytorch_state_dict[f"{pt_prefix}.weight"]
        processed_keys.add(f"{pt_prefix}.weight")
        return {
            'kernel': transpose_conv_weight(w),
            'bias': get_val(f"{pt_prefix}.bias")
        }

    def load_dense(pt_prefix):
        # Linear weight+bias under `pt_prefix`, kernel transposed to (in, out).
        w = pytorch_state_dict[f"{pt_prefix}.weight"]
        processed_keys.add(f"{pt_prefix}.weight")
        return {
            'kernel': transpose_dense_weight(w),
            'bias': get_val(f"{pt_prefix}.bias")
        }

    def load_bn(pt_prefix):
        # BatchNorm: affine params go to 'params', running stats to 'stats'
        # (Flax keeps them in separate collections).
        return {
            'params': {
                'scale': get_val(f"{pt_prefix}.weight"),
                'bias': get_val(f"{pt_prefix}.bias")
            },
            'stats': {
                'mean': get_val(f"{pt_prefix}.running_mean"),
                'var': get_val(f"{pt_prefix}.running_var")
            }
        }

    def mark_bn_tracked(pt_prefix):
        # num_batches_tracked has no Flax equivalent; mark it consumed if present
        # so the leftover-key check does not flag it.
        k = f"{pt_prefix}.num_batches_tracked"
        if k in pytorch_state_dict: processed_keys.add(k)

    if verbose: print("Converting PyTorch weights to Flax format...")

    # --- Stem: three fused convs under patch_embed. ---
    for i in range(3):
        flax_params[f'stem_{i}'] = {'reparam_conv': load_conv(f"patch_embed.{i}.reparam_conv")}

    # --- Stages: network_idx walks the flat `network.N` module list, which
    # interleaves optional positional-embedding convs, per-stage block groups,
    # and downsample layers. ---
    network_idx = 0
    for stage_idx, num_blocks in enumerate(stage_config):
        mixer_type = token_mixers[stage_idx]

        # Attention stages may be preceded by a conditional positional-embedding
        # conv; it occupies a network slot only when present in the checkpoint.
        if mixer_type == 'attention':
            pt_prefix = f"network.{network_idx}.reparam_conv"
            if f"{pt_prefix}.weight" in pytorch_state_dict:
                flax_params[f'pos_embed_{stage_idx}'] = {'reparam_conv': load_conv(pt_prefix)}
                network_idx += 1

        for block_idx in range(num_blocks):
            pt_block = f"network.{network_idx}.{block_idx}"
            flax_block = f"stages_{stage_idx}_{block_idx}"

            block_p = {}  # params for this block
            block_s = {}  # batch stats for this block (attention blocks only)

            if mixer_type == 'repmixer':
                # RepMixer block: fused token-mixer conv + a single layer scale.
                block_p['token_mixer'] = {'reparam_conv': load_conv(f"{pt_block}.token_mixer.reparam_conv")}

                ls_key = f"{pt_block}.layer_scale"
                block_p['layer_scale'] = reshape_layer_scale(pytorch_state_dict[ls_key])
                processed_keys.add(ls_key)

            elif mixer_type == 'attention':
                # Attention block: BN norm, qkv (no bias) + proj attention weights,
                # and layer_scale_1 for the attention branch.
                bn_data = load_bn(f"{pt_block}.norm")
                block_p['norm'] = bn_data['params']
                block_s['norm'] = bn_data['stats']
                mark_bn_tracked(f"{pt_block}.norm")

                # qkv has a weight but no bias in this checkpoint schema, so it
                # bypasses load_dense (which would demand a `.bias` key).
                w_qkv = pytorch_state_dict[f"{pt_block}.token_mixer.qkv.weight"]
                processed_keys.add(f"{pt_block}.token_mixer.qkv.weight")

                block_p['token_mixer'] = {
                    'qkv': {'kernel': transpose_dense_weight(w_qkv)},
                    'proj': load_dense(f"{pt_block}.token_mixer.proj")
                }

                ls_key = f"{pt_block}.layer_scale_1"
                block_p['layer_scale_1'] = reshape_layer_scale(pytorch_state_dict[ls_key])
                processed_keys.add(ls_key)

            # ConvFFN is shared by both block flavors: depthwise-style conv
            # (weight only) + fc1/fc2 convs.
            block_p['convffn'] = {
                'conv_conv': {'kernel': transpose_conv_weight(pytorch_state_dict[f"{pt_block}.convffn.conv.conv.weight"])},
                'fc1': load_conv(f"{pt_block}.convffn.fc1"),
                'fc2': load_conv(f"{pt_block}.convffn.fc2")
            }
            processed_keys.add(f"{pt_block}.convffn.conv.conv.weight")

            # The ConvFFN's BN contributes both params and running stats.
            bn_data = load_bn(f"{pt_block}.convffn.conv.bn")
            block_p['convffn']['conv_bn'] = bn_data['params']
            block_s['convffn'] = {'conv_bn': bn_data['stats']}
            mark_bn_tracked(f"{pt_block}.convffn.conv.bn")

            # Attention blocks carry a second layer scale for the FFN branch.
            if mixer_type == 'attention':
                ls2_key = f"{pt_block}.layer_scale_2"
                block_p['layer_scale_2'] = reshape_layer_scale(pytorch_state_dict[ls2_key])
                processed_keys.add(ls2_key)

            flax_params[flax_block] = block_p
            if block_s: flax_batch_stats[flax_block] = block_s

        # The whole block group consumed one `network.N` slot.
        network_idx += 1

        # Every stage but the last is followed by a downsample module
        # (large-kernel reparam conv + fused conv) in its own network slot.
        if stage_idx < len(stage_config) - 1:
            pt_ds = f"network.{network_idx}.proj"
            flax_params[f"downsample_layers_{stage_idx}"] = {
                'proj_0': {'lkb_reparam': load_conv(f"{pt_ds}.0.lkb_reparam")},
                'proj_1': {'reparam_conv': load_conv(f"{pt_ds}.1.reparam_conv")}
            }
            network_idx += 1

    # --- Final expansion conv with squeeze-excite, then the classifier head. ---
    flax_params['conv_exp'] = {
        'reparam_conv': load_conv("conv_exp.reparam_conv"),
        'se': {
            'reduce': load_conv("conv_exp.se.reduce"),
            'expand': load_conv("conv_exp.se.expand")
        }
    }
    flax_params['head'] = load_dense("head")

    # --- Completeness check: every checkpoint key (except num_batches_tracked)
    # must have been consumed above. "missing" here means *unconverted*: present
    # in the checkpoint but never mapped to a Flax parameter. ---
    all_pt_keys = set(pytorch_state_dict.keys())
    relevant_keys = {k for k in all_pt_keys if "num_batches_tracked" not in k}
    missing_keys = relevant_keys - processed_keys

    if missing_keys:
        raise RuntimeError(f"Missing keys in conversion: {sorted(missing_keys)}")

    if verbose: print("✓ Conversion successful!")
    return {'params': flax_params, 'batch_stats': flax_batch_stats}
|
|
|
|
def load_fused_weights(pytorch_weights_path, stage_config, token_mixers, verbose=False):
    """Read a fused FastViT checkpoint from disk and convert it to Flax trees.

    Args:
        pytorch_weights_path: Path to the ``.pth`` checkpoint file.
        stage_config: Number of blocks per stage (forwarded to the converter).
        token_mixers: Per-stage token-mixer types (forwarded to the converter).
        verbose: If True, print progress messages.

    Returns:
        Dict with 'params' and 'batch_stats' for the Flax model.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    if verbose:
        print(f"Loading weights from {pytorch_weights_path}")
    if not os.path.exists(pytorch_weights_path):
        raise FileNotFoundError(f"Weights not found: {pytorch_weights_path}")
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
    state = torch.load(pytorch_weights_path, map_location='cpu')
    return convert_pytorch_to_flax(state, stage_config, token_mixers, verbose=verbose)
|
|
|
|
def main():
    """CLI entry point: convert a fused FastViT checkpoint and smoke-test a forward pass.

    Selects a model from MODEL_REGISTRY, converts ``weights/<model>_fused.pth``
    to Flax format, runs one random-input forward pass, and prints the top class.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='fastvit_t8', choices=list(MODEL_REGISTRY.keys()))
    parser.add_argument('--verbose', action='store_true')
    args = parser.parse_args()

    model_fn, stage_config, token_mixers = MODEL_REGISTRY[args.model]

    model = model_fn(num_classes=1000)
    weights_path = f"weights/{args.model}_fused.pth"

    # BUG FIX: `model` was previously passed as the first positional argument,
    # shifting every parameter of load_fused_weights(path, stage_config,
    # token_mixers, verbose=...) by one and supplying a fifth argument the
    # function does not accept — a guaranteed TypeError at runtime.
    params = load_fused_weights(weights_path, stage_config, token_mixers, verbose=args.verbose)

    print("Testing forward pass...")
    key = jax.random.PRNGKey(42)
    # NHWC input; 256x256 presumably matches the fused-inference resolution —
    # TODO confirm against the model configuration.
    x = jax.random.normal(key, (1, 256, 256, 3))
    output = model.apply(params, x, train=False)

    probs = jax.nn.softmax(output, axis=-1)
    top5 = jnp.argsort(probs[0])[-5:][::-1]

    print(f"Top prediction: Class {top5[0]} ({probs[0, top5[0]]:.4f})")


if __name__ == "__main__":
    main()
|
|