#!/usr/bin/env python3
"""
Convert Thera RDN weights (air/pro) from Flax pickle to MLX safetensors format.

Requires: jax, flax, numpy, safetensors, huggingface_hub
These are only needed for conversion, not for inference.
"""
import argparse
import os
import pickle
import sys

import numpy as np


def conv_weight(kernel):
    """Flax Conv (H, W, C_in, C_out) → MLX Conv2d (C_out, H, W, C_in).

    The transpose result is forced C-contiguous: ``np.transpose`` returns a
    strided view, and ``safetensors`` refuses to serialize non-contiguous
    arrays, which would make every save of converted weights fail.
    """
    return np.ascontiguousarray(np.transpose(np.asarray(kernel), (3, 0, 1, 2)))


def dense_weight(kernel):
    """Flax Dense (in, out) → MLX Linear (out, in).

    Made C-contiguous for the same safetensors reason as ``conv_weight``.
    """
    return np.ascontiguousarray(np.transpose(np.asarray(kernel), (1, 0)))


def layernorm_params(flax_ln, prefix):
    """Map Flax LayerNorm scale/bias to MLX weight/bias.

    Args:
        flax_ln: Flax LayerNorm param dict with 'scale' and optionally 'bias'.
        prefix: dotted MLX parameter-name prefix (e.g. 'refine.norm').

    Returns:
        dict mapping '{prefix}.weight' (and '.bias' if present) to arrays.
    """
    w = {f'{prefix}.weight': np.asarray(flax_ln['scale'])}
    # 'bias' is absent when the Flax layer was created with use_bias=False.
    if 'bias' in flax_ln:
        w[f'{prefix}.bias'] = np.asarray(flax_ln['bias'])
    return w


def convert_rdn_encoder(enc):
    """Convert RDN backbone weights.

    Expects the Flax 'encoder' subtree: Conv_0/Conv_1 (shallow feature
    extraction), RDB_0..RDB_15 (each with 8 inner RDB_Conv layers plus a
    1x1 local-fusion Conv_0), and Conv_2/Conv_3 (global feature fusion).
    """
    weights = {}

    # SFE1 (Conv_0) and SFE2 (Conv_1)
    weights['encoder.sfe1.weight'] = conv_weight(enc['Conv_0']['kernel'])
    weights['encoder.sfe1.bias'] = np.asarray(enc['Conv_0']['bias'])
    weights['encoder.sfe2.weight'] = conv_weight(enc['Conv_1']['kernel'])
    weights['encoder.sfe2.bias'] = np.asarray(enc['Conv_1']['bias'])

    # 16 Residual Dense Blocks
    for i in range(16):
        rdb = enc[f'RDB_{i}']
        # 8 RDB_Conv layers per block
        for j in range(8):
            rc = rdb[f'RDB_Conv_{j}']['Conv_0']
            prefix = f'encoder.rdbs.{i}.convs.{j}.conv'
            weights[f'{prefix}.weight'] = conv_weight(rc['kernel'])
            weights[f'{prefix}.bias'] = np.asarray(rc['bias'])
        # Local fusion 1x1 conv (Conv_0 at RDB level)
        lf = rdb['Conv_0']
        prefix = f'encoder.rdbs.{i}.local_fusion'
        weights[f'{prefix}.weight'] = conv_weight(lf['kernel'])
        weights[f'{prefix}.bias'] = np.asarray(lf['bias'])

    # Global Feature Fusion (Conv_2 = 1x1, Conv_3 = 3x3)
    weights['encoder.gff_1x1.weight'] = conv_weight(enc['Conv_2']['kernel'])
    weights['encoder.gff_1x1.bias'] = np.asarray(enc['Conv_2']['bias'])
    weights['encoder.gff_3x3.weight'] = conv_weight(enc['Conv_3']['kernel'])
    weights['encoder.gff_3x3.bias'] = np.asarray(enc['Conv_3']['bias'])
    return weights


def convert_swinir_tail(ref):
    """Convert SwinIR tail (refine) weights for rdn-pro.

    Layout mirrored from the Flax tree: Conv_0 (conv_first), a PatchEmbed
    LayerNorm, two RSTBs with 7 and 6 SwinTransformerBlocks respectively,
    a final LayerNorm, and Conv_1/Conv_2 (conv_after_body / conv_last).
    """
    weights = {}

    # conv_first: refine/Conv_0
    weights['refine.conv_first.weight'] = conv_weight(ref['Conv_0']['kernel'])
    weights['refine.conv_first.bias'] = np.asarray(ref['Conv_0']['bias'])

    # patch_embed_norm: refine/PatchEmbed_0/LayerNorm_0
    weights.update(layernorm_params(
        ref['PatchEmbed_0']['LayerNorm_0'], 'refine.patch_embed_norm'))

    # RSTB layers
    rstb_depths = [7, 6]  # number of SwinTransformerBlocks per RSTB
    for i, depth in enumerate(rstb_depths):
        rstb = ref[f'RSTB_{i}']
        basic = rstb['BasicLayer_0']
        for j in range(depth):
            stb = basic[f'SwinTransformerBlock_{j}']
            mlx_prefix = f'refine.layers.{i}.blocks.{j}'

            # LayerNorm_0 → norm1
            weights.update(layernorm_params(
                stb['LayerNorm_0'], f'{mlx_prefix}.norm1'))

            # WindowAttention_0
            wa = stb['WindowAttention_0']
            # qkv Dense → Linear
            weights[f'{mlx_prefix}.attn.qkv.weight'] = dense_weight(wa['qkv']['kernel'])
            weights[f'{mlx_prefix}.attn.qkv.bias'] = np.asarray(wa['qkv']['bias'])
            # proj Dense → Linear
            weights[f'{mlx_prefix}.attn.proj.weight'] = dense_weight(wa['proj']['kernel'])
            weights[f'{mlx_prefix}.attn.proj.bias'] = np.asarray(wa['proj']['bias'])
            # relative_position_bias_table (no transform needed)
            weights[f'{mlx_prefix}.attn.relative_position_bias_table'] = \
                np.asarray(wa['relative_position_bias_table'])

            # LayerNorm_1 → norm2
            weights.update(layernorm_params(
                stb['LayerNorm_1'], f'{mlx_prefix}.norm2'))

            # Mlp_0 → mlp
            mlp = stb['Mlp_0']
            weights[f'{mlx_prefix}.mlp.fc1.weight'] = dense_weight(mlp['Dense_0']['kernel'])
            weights[f'{mlx_prefix}.mlp.fc1.bias'] = np.asarray(mlp['Dense_0']['bias'])
            weights[f'{mlx_prefix}.mlp.fc2.weight'] = dense_weight(mlp['Dense_1']['kernel'])
            weights[f'{mlx_prefix}.mlp.fc2.bias'] = np.asarray(mlp['Dense_1']['bias'])

        # RSTB conv: RSTB_{i}/Conv_0
        weights[f'refine.layers.{i}.conv.weight'] = conv_weight(rstb['Conv_0']['kernel'])
        weights[f'refine.layers.{i}.conv.bias'] = np.asarray(rstb['Conv_0']['bias'])

    # Final norm: refine/LayerNorm_0
    weights.update(layernorm_params(ref['LayerNorm_0'], 'refine.norm'))

    # conv_after_body: refine/Conv_1
    weights['refine.conv_after_body.weight'] = conv_weight(ref['Conv_1']['kernel'])
    weights['refine.conv_after_body.bias'] = np.asarray(ref['Conv_1']['bias'])

    # conv_last: refine/Conv_2
    weights['refine.conv_last.weight'] = conv_weight(ref['Conv_2']['kernel'])
    weights['refine.conv_last.bias'] = np.asarray(ref['Conv_2']['bias'])
    return weights


def convert_flax_to_mlx(flax_params, size='air'):
    """Map Flax parameter tree to flat MLX weight dict.

    Args:
        flax_params: checkpoint['model'] tree with a top-level 'params' key.
        size: 'air' (RDN only) or 'pro' (RDN + SwinIR refine tail).

    Returns:
        Flat dict of MLX parameter names → numpy arrays.
    """
    p = flax_params['params']
    weights = {}

    # --- Global params (no transform) ---
    # 'k' is stored as a 0-d scalar; reshape(()) normalizes any 1-element shape.
    weights['k'] = np.asarray(p['k'], dtype=np.float32).reshape(())
    weights['components'] = np.asarray(p['components'], dtype=np.float32)

    # --- RDN Backbone ---
    weights.update(convert_rdn_encoder(p['encoder']))

    # --- SwinIR tail (pro only) ---
    if size == 'pro':
        weights.update(convert_swinir_tail(p['refine']))

    # --- Hypernetwork output conv ---
    weights['out_conv.weight'] = conv_weight(p['out_conv']['kernel'])
    weights['out_conv.bias'] = np.asarray(p['out_conv']['bias'])
    return weights


REPO_IDS = {
    'air': 'prs-eth/thera-rdn-air',
    'pro': 'prs-eth/thera-rdn-pro',
}


def download_model(size='air', filename="model.pkl", cache_dir=None):
    """Download model pickle from HuggingFace."""
    from huggingface_hub import hf_hub_download
    repo_id = REPO_IDS[size]
    return hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)


def load_pickle_with_jax(path):
    """Load Flax pickle (requires jax and flax installed).

    Returns the 'model' parameter tree; exits the process if jax/flax are
    missing since unpickling needs their classes to reconstruct arrays.
    """
    try:
        import jax    # noqa: F401 - needed for array reconstruction
        import flax   # noqa: F401 - needed for FrozenDict
    except ImportError:
        print("Error: jax and flax are required for loading the original weights.")
        print("Install them with: pip install jax flax")
        sys.exit(1)

    with open(path, 'rb') as f:
        checkpoint = pickle.load(f)

    params = checkpoint['model']
    backbone = checkpoint['backbone']
    size = checkpoint['size']
    print(f"Loaded checkpoint: backbone={backbone}, size={size}")
    if backbone != 'rdn':
        print(f"Warning: this converter is designed for rdn, got {backbone}")
    return params


def load_pickle_without_jax(path):
    """
    Attempt to load Flax pickle without JAX by mocking the required classes.
    Falls back to JAX-based loading if this fails.
    """
    import types

    class MockFrozenDict(dict):
        # Stand-in for flax.core.frozen_dict.FrozenDict; plain-dict access
        # is all the converter needs.
        pass

    class MockModule(types.ModuleType):
        def __getattr__(self, name):
            return MockModule(name)

    class NumpyUnpickler(pickle.Unpickler):
        def find_class(self, module, name):
            if 'frozen_dict' in module and name == 'FrozenDict':
                return MockFrozenDict
            if module.startswith('jax') and name == '_reconstruct_array':
                # JAX arrays are reconstructed from numpy arrays + metadata
                def reconstruct(*args):
                    # args typically: (numpy_array, dtype, weak_type)
                    if len(args) >= 1 and isinstance(args[0], np.ndarray):
                        return args[0]
                    return np.array(args[0])
                return reconstruct
            if module.startswith('jax'):
                try:
                    return super().find_class(module, name)
                except (ImportError, AttributeError):
                    # Best-effort identity shim for any other jax callable.
                    return lambda *a, **kw: a[0] if a else None
            return super().find_class(module, name)

    try:
        with open(path, 'rb') as f:
            checkpoint = NumpyUnpickler(f).load()
        params = checkpoint['model']
        backbone = checkpoint['backbone']
        size = checkpoint['size']
        print(f"Loaded checkpoint (no-jax mode): backbone={backbone}, size={size}")
        return params
    except Exception as e:
        # Deliberately broad: any mock-unpickling failure falls back to the
        # real jax/flax loader rather than crashing the conversion.
        print(f"Mock unpickle failed ({e}), falling back to JAX-based loading...")
        return load_pickle_with_jax(path)


def save_safetensors(weights, output_path):
    """Save weight dict as safetensors.

    Arrays are coerced C-contiguous defensively: safetensors raises on
    non-contiguous inputs (e.g. transpose views).
    """
    from safetensors.numpy import save_file
    save_file({k: np.ascontiguousarray(v) for k, v in weights.items()},
              output_path)
    print(f"Saved MLX weights to {output_path}")


def save_npz(weights, output_path):
    """Save weight dict as npz (fallback if safetensors not available).

    NOTE: np.savez appends '.npz' if output_path lacks it, so the printed
    path may differ from the file actually written in that case.
    """
    np.savez(output_path, **weights)
    print(f"Saved MLX weights to {output_path}")


def main():
    parser = argparse.ArgumentParser(description="Convert Thera RDN weights to MLX format")
    parser.add_argument('--model', type=str, choices=['air', 'pro'], default='air',
                        help='Model variant (default: air)')
    parser.add_argument('--input', type=str, default=None,
                        help='Path to model.pkl (downloads from HuggingFace if not provided)')
    parser.add_argument('--output', type=str, default=None,
                        help='Output path (default: weights-{model}.safetensors)')
    parser.add_argument('--no-jax', action='store_true',
                        help='Try to load pickle without JAX installed')
    args = parser.parse_args()

    if args.output is None:
        args.output = f'weights-{args.model}.safetensors'

    # Download if needed
    if args.input is None:
        repo = REPO_IDS[args.model]
        print(f"Downloading model from HuggingFace ({repo})...")
        pkl_path = download_model(args.model)
    else:
        pkl_path = args.input

    # Load
    if args.no_jax:
        flax_params = load_pickle_without_jax(pkl_path)
    else:
        flax_params = load_pickle_with_jax(pkl_path)

    # Convert
    print("Converting weights...")
    mlx_weights = convert_flax_to_mlx(flax_params, size=args.model)

    # Print summary
    total_params = sum(w.size for w in mlx_weights.values())
    print(f"Total parameters: {total_params:,}")
    print(f"Weight entries: {len(mlx_weights)}")

    # Save
    output_path = args.output
    if output_path.endswith('.safetensors'):
        try:
            save_safetensors(mlx_weights, output_path)
        except ImportError:
            # safetensors is an optional dependency; degrade to npz.
            output_path = output_path.replace('.safetensors', '.npz')
            print("safetensors not installed, saving as npz instead")
            save_npz(mlx_weights, output_path)
    else:
        save_npz(mlx_weights, output_path)


if __name__ == '__main__':
    main()