| |
| """ |
| Convert Thera RDN weights (air/pro) from Flax pickle to MLX safetensors format. |
| |
| Requires: jax, flax, numpy, safetensors, huggingface_hub |
| These are only needed for conversion, not for inference. |
| """ |
|
|
| import argparse |
| import os |
| import pickle |
| import sys |
|
|
| import numpy as np |
|
|
|
|
def conv_weight(kernel):
    """Rearrange a Flax Conv kernel (H, W, C_in, C_out) into MLX Conv2d layout (C_out, H, W, C_in)."""
    arr = np.asarray(kernel)
    return arr.transpose(3, 0, 1, 2)
|
|
|
|
def dense_weight(kernel):
    """Rearrange a Flax Dense kernel (in, out) into MLX Linear layout (out, in)."""
    return np.asarray(kernel).T
|
|
|
|
def layernorm_params(flax_ln, prefix):
    """Map a Flax LayerNorm param dict to MLX naming: 'scale' -> '.weight', optional 'bias' -> '.bias'."""
    mapped = {f'{prefix}.weight': np.asarray(flax_ln['scale'])}
    if 'bias' in flax_ln:
        mapped[f'{prefix}.bias'] = np.asarray(flax_ln['bias'])
    return mapped
|
|
|
|
def convert_rdn_encoder(enc, num_rdbs=16, convs_per_rdb=8):
    """Convert RDN backbone weights to flat MLX naming.

    Args:
        enc: Flax param subtree for the RDN encoder (keys like 'Conv_0',
            'RDB_0', ...).
        num_rdbs: number of residual dense blocks. Default matches the
            released Thera RDN checkpoints.
        convs_per_rdb: conv layers inside each RDB. Default matches the
            released checkpoints.

    Returns:
        dict mapping MLX parameter names to numpy arrays.
    """
    weights = {}

    # Shallow feature extraction: the first two convs of the Flax model.
    weights['encoder.sfe1.weight'] = conv_weight(enc['Conv_0']['kernel'])
    weights['encoder.sfe1.bias'] = np.asarray(enc['Conv_0']['bias'])
    weights['encoder.sfe2.weight'] = conv_weight(enc['Conv_1']['kernel'])
    weights['encoder.sfe2.bias'] = np.asarray(enc['Conv_1']['bias'])

    # Residual dense blocks.
    for i in range(num_rdbs):
        rdb = enc[f'RDB_{i}']
        for j in range(convs_per_rdb):
            rc = rdb[f'RDB_Conv_{j}']['Conv_0']
            prefix = f'encoder.rdbs.{i}.convs.{j}.conv'
            weights[f'{prefix}.weight'] = conv_weight(rc['kernel'])
            weights[f'{prefix}.bias'] = np.asarray(rc['bias'])
        # Local feature fusion conv at the end of each RDB.
        lf = rdb['Conv_0']
        prefix = f'encoder.rdbs.{i}.local_fusion'
        weights[f'{prefix}.weight'] = conv_weight(lf['kernel'])
        weights[f'{prefix}.bias'] = np.asarray(lf['bias'])

    # Global feature fusion: 1x1 conv followed by 3x3 conv.
    weights['encoder.gff_1x1.weight'] = conv_weight(enc['Conv_2']['kernel'])
    weights['encoder.gff_1x1.bias'] = np.asarray(enc['Conv_2']['bias'])
    weights['encoder.gff_3x3.weight'] = conv_weight(enc['Conv_3']['kernel'])
    weights['encoder.gff_3x3.bias'] = np.asarray(enc['Conv_3']['bias'])

    return weights
|
|
|
|
def convert_swinir_tail(ref, rstb_depths=(7, 6)):
    """Convert SwinIR tail (refine) weights for rdn-pro to flat MLX naming.

    Args:
        ref: Flax param subtree for the SwinIR refinement stage.
        rstb_depths: number of Swin transformer blocks per RSTB layer.
            Default matches the released rdn-pro checkpoint.

    Returns:
        dict mapping MLX parameter names to numpy arrays.
    """
    weights = {}

    # Initial conv before the transformer body.
    weights['refine.conv_first.weight'] = conv_weight(ref['Conv_0']['kernel'])
    weights['refine.conv_first.bias'] = np.asarray(ref['Conv_0']['bias'])

    # Patch-embedding layer norm.
    weights.update(layernorm_params(
        ref['PatchEmbed_0']['LayerNorm_0'], 'refine.patch_embed_norm'))

    # Residual Swin Transformer Blocks.
    for i, depth in enumerate(rstb_depths):
        rstb = ref[f'RSTB_{i}']
        basic = rstb['BasicLayer_0']

        for j in range(depth):
            stb = basic[f'SwinTransformerBlock_{j}']
            mlx_prefix = f'refine.layers.{i}.blocks.{j}'

            # Pre-attention norm.
            weights.update(layernorm_params(
                stb['LayerNorm_0'], f'{mlx_prefix}.norm1'))

            # Window attention: fused qkv, output projection, relative bias.
            wa = stb['WindowAttention_0']
            weights[f'{mlx_prefix}.attn.qkv.weight'] = dense_weight(wa['qkv']['kernel'])
            weights[f'{mlx_prefix}.attn.qkv.bias'] = np.asarray(wa['qkv']['bias'])
            weights[f'{mlx_prefix}.attn.proj.weight'] = dense_weight(wa['proj']['kernel'])
            weights[f'{mlx_prefix}.attn.proj.bias'] = np.asarray(wa['proj']['bias'])
            weights[f'{mlx_prefix}.attn.relative_position_bias_table'] = \
                np.asarray(wa['relative_position_bias_table'])

            # Pre-MLP norm.
            weights.update(layernorm_params(
                stb['LayerNorm_1'], f'{mlx_prefix}.norm2'))

            # Two-layer MLP.
            mlp = stb['Mlp_0']
            weights[f'{mlx_prefix}.mlp.fc1.weight'] = dense_weight(mlp['Dense_0']['kernel'])
            weights[f'{mlx_prefix}.mlp.fc1.bias'] = np.asarray(mlp['Dense_0']['bias'])
            weights[f'{mlx_prefix}.mlp.fc2.weight'] = dense_weight(mlp['Dense_1']['kernel'])
            weights[f'{mlx_prefix}.mlp.fc2.bias'] = np.asarray(mlp['Dense_1']['bias'])

        # Per-RSTB conv after the block stack.
        weights[f'refine.layers.{i}.conv.weight'] = conv_weight(rstb['Conv_0']['kernel'])
        weights[f'refine.layers.{i}.conv.bias'] = np.asarray(rstb['Conv_0']['bias'])

    # Final norm of the transformer body.
    weights.update(layernorm_params(ref['LayerNorm_0'], 'refine.norm'))

    # Conv after body (residual branch) and final output conv.
    weights['refine.conv_after_body.weight'] = conv_weight(ref['Conv_1']['kernel'])
    weights['refine.conv_after_body.bias'] = np.asarray(ref['Conv_1']['bias'])

    weights['refine.conv_last.weight'] = conv_weight(ref['Conv_2']['kernel'])
    weights['refine.conv_last.bias'] = np.asarray(ref['Conv_2']['bias'])

    return weights
|
|
|
|
def convert_flax_to_mlx(flax_params, size='air'):
    """Flatten the Flax parameter tree into a flat MLX weight dict.

    Args:
        flax_params: checkpoint dict containing a 'params' subtree.
        size: 'air' (RDN backbone only) or 'pro' (adds the SwinIR tail).
    """
    params = flax_params['params']

    # Thera field parameters: scalar k and the component matrix.
    converted = {
        'k': np.asarray(params['k'], dtype=np.float32).reshape(()),
        'components': np.asarray(params['components'], dtype=np.float32),
    }

    # RDN backbone.
    converted.update(convert_rdn_encoder(params['encoder']))

    # The 'pro' variant carries an extra SwinIR refinement stage.
    if size == 'pro':
        converted.update(convert_swinir_tail(params['refine']))

    # Final output projection conv.
    converted['out_conv.weight'] = conv_weight(params['out_conv']['kernel'])
    converted['out_conv.bias'] = np.asarray(params['out_conv']['bias'])

    return converted
|
|
|
|
# HuggingFace Hub repository ids for the two released Thera RDN variants.
REPO_IDS = {
    'air': 'prs-eth/thera-rdn-air',
    'pro': 'prs-eth/thera-rdn-pro',
}
|
|
|
|
def download_model(size='air', filename="model.pkl", cache_dir=None):
    """Fetch the pickled Flax checkpoint for the given variant from HuggingFace Hub."""
    # Imported lazily: huggingface_hub is only needed for conversion.
    from huggingface_hub import hf_hub_download
    return hf_hub_download(repo_id=REPO_IDS[size], filename=filename, cache_dir=cache_dir)
|
|
|
|
def load_pickle_with_jax(path):
    """Load a Flax checkpoint pickle; exits with an error if jax/flax are missing."""
    try:
        import jax  # noqa: F401 -- imported only to verify availability
        import flax  # noqa: F401
    except ImportError:
        print("Error: jax and flax are required for loading the original weights.")
        print("Install them with: pip install jax flax")
        sys.exit(1)

    with open(path, 'rb') as f:
        checkpoint = pickle.load(f)

    backbone = checkpoint['backbone']
    size = checkpoint['size']
    print(f"Loaded checkpoint: backbone={backbone}, size={size}")

    # Other backbones may unpickle fine but will not match the key layout
    # this converter expects, so warn instead of failing hard.
    if backbone != 'rdn':
        print(f"Warning: this converter is designed for rdn, got {backbone}")

    return checkpoint['model']
|
|
|
|
def load_pickle_without_jax(path):
    """
    Attempt to load Flax pickle without JAX by mocking the required classes.
    Falls back to JAX-based loading if this fails.
    """
    import types

    class MockFrozenDict(dict):
        # Stand-in for flax's FrozenDict; plain dict access is all we need.
        pass

    # NOTE(review): MockModule appears unused below -- NumpyUnpickler never
    # references it. Kept as-is; confirm before removing.
    class MockModule(types.ModuleType):
        def __getattr__(self, name):
            return MockModule(name)

    class NumpyUnpickler(pickle.Unpickler):
        # Custom unpickler that substitutes numpy-friendly stand-ins for
        # flax/jax globals so the checkpoint can load without those packages.
        def find_class(self, module, name):
            if 'frozen_dict' in module and name == 'FrozenDict':
                return MockFrozenDict
            if module.startswith('jax') and name == '_reconstruct_array':
                # presumably the first pickled argument is (or converts to)
                # the underlying numpy buffer -- TODO confirm against the
                # jax pickling protocol for device arrays.
                def reconstruct(*args):
                    if len(args) >= 1 and isinstance(args[0], np.ndarray):
                        return args[0]
                    return np.array(args[0])
                return reconstruct
            if module.startswith('jax'):
                try:
                    return super().find_class(module, name)
                except (ImportError, AttributeError):
                    # jax missing entirely: best-effort identity callable.
                    return lambda *a, **kw: a[0] if a else None
            return super().find_class(module, name)

    try:
        with open(path, 'rb') as f:
            checkpoint = NumpyUnpickler(f).load()
        params = checkpoint['model']
        backbone = checkpoint['backbone']
        size = checkpoint['size']
        print(f"Loaded checkpoint (no-jax mode): backbone={backbone}, size={size}")
        return params
    except Exception as e:
        # Any failure in the mock path falls back to the real jax loader.
        print(f"Mock unpickle failed ({e}), falling back to JAX-based loading...")
        return load_pickle_with_jax(path)
|
|
|
|
def save_safetensors(weights, output_path):
    """Serialize the flat weight dict to *output_path* in safetensors format."""
    # Imported lazily so npz-only usage works without safetensors installed.
    from safetensors.numpy import save_file
    save_file(weights, output_path)
    print(f"Saved MLX weights to {output_path}")
|
|
|
|
def save_npz(weights, output_path):
    """Save the weight dict as a numpy .npz archive (fallback for safetensors).

    Args:
        weights: mapping of parameter name -> numpy array.
        output_path: destination path; numpy appends '.npz' if it is missing.
    """
    np.savez(output_path, **weights)
    # np.savez silently appends '.npz' when the path lacks the extension,
    # so report the name of the file that was actually written.
    if not output_path.endswith('.npz'):
        output_path += '.npz'
    print(f"Saved MLX weights to {output_path}")
|
|
|
|
def main():
    """CLI entry point: obtain a Thera RDN checkpoint, convert it, and save it."""
    parser = argparse.ArgumentParser(description="Convert Thera RDN weights to MLX format")
    parser.add_argument('--model', type=str, choices=['air', 'pro'], default='air',
                        help='Model variant (default: air)')
    parser.add_argument('--input', type=str, default=None,
                        help='Path to model.pkl (downloads from HuggingFace if not provided)')
    parser.add_argument('--output', type=str, default=None,
                        help='Output path (default: weights-{model}.safetensors)')
    parser.add_argument('--no-jax', action='store_true',
                        help='Try to load pickle without JAX installed')
    args = parser.parse_args()

    if args.output is None:
        args.output = f'weights-{args.model}.safetensors'

    # Resolve the checkpoint: an explicit path wins, otherwise fetch from the Hub.
    if args.input is None:
        repo = REPO_IDS[args.model]
        print(f"Downloading model from HuggingFace ({repo})...")
        pkl_path = download_model(args.model)
    else:
        pkl_path = args.input

    # Load the Flax parameter tree with or without a jax dependency.
    if args.no_jax:
        flax_params = load_pickle_without_jax(pkl_path)
    else:
        flax_params = load_pickle_with_jax(pkl_path)

    print("Converting weights...")
    mlx_weights = convert_flax_to_mlx(flax_params, size=args.model)

    total_params = sum(w.size for w in mlx_weights.values())
    print(f"Total parameters: {total_params:,}")
    print(f"Weight entries: {len(mlx_weights)}")

    output_path = args.output
    if output_path.endswith('.safetensors'):
        try:
            save_safetensors(mlx_weights, output_path)
        except ImportError:
            # Swap only the trailing extension; str.replace would also
            # rewrite '.safetensors' occurring elsewhere in the path
            # (e.g. a directory named 'x.safetensors').
            output_path = output_path[:-len('.safetensors')] + '.npz'
            print("safetensors not installed, saving as npz instead")
            save_npz(mlx_weights, output_path)
    else:
        save_npz(mlx_weights, output_path)
|
|
|
|
# Allow the module to be executed directly as a conversion script.
if __name__ == '__main__':
    main()
|
|