thera-mlx / convert.py
mlmPenguin's picture
Add source code
29e0144 verified
#!/usr/bin/env python3
"""
Convert Thera RDN weights (air/pro) from Flax pickle to MLX safetensors format.
Requires: jax, flax, numpy, safetensors, huggingface_hub
These are only needed for conversion, not for inference.
"""
import argparse
import os
import pickle
import sys
import numpy as np
def conv_weight(kernel):
    """Reorder a Flax Conv kernel into the MLX Conv2d layout.

    Flax stores the kernel as (H, W, C_in, C_out); the returned numpy array
    is laid out as (C_out, H, W, C_in).
    """
    flax_kernel = np.asarray(kernel)
    # Bring the output-channel axis to the front; H, W, C_in keep their order.
    return flax_kernel.transpose((3, 0, 1, 2))
def dense_weight(kernel):
    """Reorder a Flax Dense kernel into the MLX Linear layout.

    Flax stores the kernel as (in, out); the returned numpy array is
    laid out as (out, in).
    """
    flax_kernel = np.asarray(kernel)
    # Swap the two axes so the output dimension comes first.
    return flax_kernel.transpose((1, 0))
def layernorm_params(flax_ln, prefix):
    """Translate a Flax LayerNorm entry into MLX-named parameters.

    The Flax ``scale`` becomes ``<prefix>.weight``; ``bias`` is copied to
    ``<prefix>.bias`` only when the entry contains one. Values are returned
    as numpy arrays in a new dict.
    """
    mapped = {f'{prefix}.weight': np.asarray(flax_ln['scale'])}
    if 'bias' not in flax_ln:
        # Nothing more to map when the entry carries no bias.
        return mapped
    mapped[f'{prefix}.bias'] = np.asarray(flax_ln['bias'])
    return mapped
def convert_rdn_encoder(enc, num_blocks=16, convs_per_block=8):
    """Convert RDN backbone weights.

    Args:
        enc: Flax parameter subtree of the encoder (nested mapping whose
            conv entries hold ``kernel`` and ``bias``).
        num_blocks: Number of ``RDB_{i}`` residual dense blocks to convert.
            Defaults to 16.
        convs_per_block: Number of ``RDB_Conv_{j}`` layers converted inside
            each block. Defaults to 8.

    Returns:
        Flat dict mapping MLX parameter names (prefixed with ``encoder.``)
        to numpy arrays.

    Raises:
        KeyError: If an entry expected for the requested block counts is
            missing from ``enc``.
    """
    weights = {}

    def add_conv(prefix, conv):
        # Kernels need the Flax -> MLX axis reorder; biases are copied as-is.
        weights[f'{prefix}.weight'] = conv_weight(conv['kernel'])
        weights[f'{prefix}.bias'] = np.asarray(conv['bias'])

    # SFE1 (Conv_0) and SFE2 (Conv_1)
    add_conv('encoder.sfe1', enc['Conv_0'])
    add_conv('encoder.sfe2', enc['Conv_1'])

    # Residual Dense Blocks
    for i in range(num_blocks):
        rdb = enc[f'RDB_{i}']
        # RDB_Conv layers of this block
        for j in range(convs_per_block):
            add_conv(f'encoder.rdbs.{i}.convs.{j}.conv',
                     rdb[f'RDB_Conv_{j}']['Conv_0'])
        # Local fusion 1x1 conv (Conv_0 at RDB level)
        add_conv(f'encoder.rdbs.{i}.local_fusion', rdb['Conv_0'])

    # Global Feature Fusion (Conv_2 = 1x1, Conv_3 = 3x3)
    add_conv('encoder.gff_1x1', enc['Conv_2'])
    add_conv('encoder.gff_3x3', enc['Conv_3'])
    return weights
def convert_swinir_tail(ref, rstb_depths=(7, 6)):
    """Convert SwinIR tail (refine) weights for rdn-pro.

    Args:
        ref: Flax parameter subtree of the refine module (nested mapping).
        rstb_depths: Number of SwinTransformerBlocks per RSTB, one entry
            per ``RSTB_{i}``. Defaults to ``(7, 6)``.

    Returns:
        Flat dict mapping MLX parameter names (prefixed with ``refine.``)
        to numpy arrays.

    Raises:
        KeyError: If an entry expected for the requested depths is missing
            from ``ref``.
    """
    weights = {}

    def add_conv(prefix, conv):
        # Conv kernels need the Flax -> MLX axis reorder; biases are copied.
        weights[f'{prefix}.weight'] = conv_weight(conv['kernel'])
        weights[f'{prefix}.bias'] = np.asarray(conv['bias'])

    def add_dense(prefix, dense):
        # Dense kernels are transposed to the MLX Linear layout.
        weights[f'{prefix}.weight'] = dense_weight(dense['kernel'])
        weights[f'{prefix}.bias'] = np.asarray(dense['bias'])

    def add_norm(prefix, norm):
        weights.update(layernorm_params(norm, prefix))

    # conv_first: refine/Conv_0
    add_conv('refine.conv_first', ref['Conv_0'])
    # patch_embed_norm: refine/PatchEmbed_0/LayerNorm_0
    add_norm('refine.patch_embed_norm', ref['PatchEmbed_0']['LayerNorm_0'])

    # RSTB layers
    for i, depth in enumerate(rstb_depths):
        rstb = ref[f'RSTB_{i}']
        basic = rstb['BasicLayer_0']
        for j in range(depth):
            stb = basic[f'SwinTransformerBlock_{j}']
            mlx_prefix = f'refine.layers.{i}.blocks.{j}'

            # LayerNorm_0 -> norm1
            add_norm(f'{mlx_prefix}.norm1', stb['LayerNorm_0'])

            # WindowAttention_0: qkv / proj Dense -> Linear
            wa = stb['WindowAttention_0']
            add_dense(f'{mlx_prefix}.attn.qkv', wa['qkv'])
            add_dense(f'{mlx_prefix}.attn.proj', wa['proj'])
            # relative_position_bias_table (no transform needed)
            weights[f'{mlx_prefix}.attn.relative_position_bias_table'] = \
                np.asarray(wa['relative_position_bias_table'])

            # LayerNorm_1 -> norm2
            add_norm(f'{mlx_prefix}.norm2', stb['LayerNorm_1'])

            # Mlp_0 -> mlp
            mlp = stb['Mlp_0']
            add_dense(f'{mlx_prefix}.mlp.fc1', mlp['Dense_0'])
            add_dense(f'{mlx_prefix}.mlp.fc2', mlp['Dense_1'])

        # RSTB conv: RSTB_{i}/Conv_0
        add_conv(f'refine.layers.{i}.conv', rstb['Conv_0'])

    # Final norm: refine/LayerNorm_0
    add_norm('refine.norm', ref['LayerNorm_0'])
    # conv_after_body: refine/Conv_1
    add_conv('refine.conv_after_body', ref['Conv_1'])
    # conv_last: refine/Conv_2
    add_conv('refine.conv_last', ref['Conv_2'])
    return weights
def convert_flax_to_mlx(flax_params, size='air'):
    """Flatten the Flax parameter tree into a dict of MLX-named arrays.

    Reads ``flax_params['params']`` and emits, in order: the global ``k``
    and ``components`` values (cast to float32), the RDN encoder weights,
    the SwinIR refine weights (only when ``size == 'pro'``), and the
    ``out_conv`` weight and bias.
    """
    tree = flax_params['params']

    # Global params: only a dtype cast; ``k`` is reshaped to a 0-d array.
    converted = {
        'k': np.asarray(tree['k'], dtype=np.float32).reshape(()),
        'components': np.asarray(tree['components'], dtype=np.float32),
    }

    # RDN backbone
    converted.update(convert_rdn_encoder(tree['encoder']))

    # SwinIR tail exists only in the pro variant
    if size == 'pro':
        converted.update(convert_swinir_tail(tree['refine']))

    # Hypernetwork output conv
    head = tree['out_conv']
    converted['out_conv.weight'] = conv_weight(head['kernel'])
    converted['out_conv.bias'] = np.asarray(head['bias'])
    return converted
# HuggingFace repository id for each supported model variant.
REPO_IDS = dict(
    air='prs-eth/thera-rdn-air',
    pro='prs-eth/thera-rdn-pro',
)
def download_model(size='air', filename="model.pkl", cache_dir=None):
    """Fetch the model pickle for *size* from HuggingFace.

    The repository is looked up in ``REPO_IDS``; the result of
    ``hf_hub_download`` is returned unchanged (``main`` uses it as the path
    of the pickle to open).
    """
    # Function-scope import: huggingface_hub is only loaded when a download
    # is actually requested.
    from huggingface_hub import hf_hub_download

    request = dict(
        repo_id=REPO_IDS[size],
        filename=filename,
        cache_dir=cache_dir,
    )
    return hf_hub_download(**request)
def load_pickle_with_jax(path):
    """Load Flax pickle (requires jax and flax installed).

    Args:
        path: Filesystem path of the checkpoint pickle.

    Returns:
        The ``'model'`` entry of the unpickled checkpoint.

    Raises:
        SystemExit: With status 1 when jax or flax cannot be imported.
        KeyError: When the checkpoint lacks ``'model'``, ``'backbone'`` or
            ``'size'``.
    """
    try:
        import jax  # noqa: F401 - needed for array reconstruction
        import flax  # noqa: F401 - needed for FrozenDict
    except ImportError:
        # Diagnostics go to stderr so they do not mix with regular output.
        print("Error: jax and flax are required for loading the original weights.",
              file=sys.stderr)
        print("Install them with: pip install jax flax", file=sys.stderr)
        sys.exit(1)

    # SECURITY NOTE: pickle.load can execute arbitrary code embedded in the
    # file. Only load checkpoints obtained from a source you trust.
    with open(path, 'rb') as f:
        checkpoint = pickle.load(f)

    params = checkpoint['model']
    backbone = checkpoint['backbone']
    size = checkpoint['size']
    print(f"Loaded checkpoint: backbone={backbone}, size={size}")

    if backbone != 'rdn':
        print(f"Warning: this converter is designed for rdn, got {backbone}",
              file=sys.stderr)
    return params
def load_pickle_without_jax(path):
    """
    Attempt to load Flax pickle without JAX by mocking the required classes.
    Falls back to JAX-based loading if this fails.

    Args:
        path: Filesystem path of the checkpoint pickle.

    Returns:
        The ``'model'`` entry of the unpickled checkpoint.
    """

    class MockFrozenDict(dict):
        """Plain-dict stand-in used wherever the pickle names FrozenDict."""

    class NumpyUnpickler(pickle.Unpickler):
        """Unpickler that swaps flax/jax globals for numpy-friendly mocks."""

        def find_class(self, module, name):
            if 'frozen_dict' in module and name == 'FrozenDict':
                return MockFrozenDict
            if module.startswith('jax') and name == '_reconstruct_array':
                def reconstruct(*args):
                    first = args[0]
                    if isinstance(first, np.ndarray):
                        # Already a numpy array: use it directly.
                        return first
                    if callable(first) and len(args) >= 3:
                        # Presumably JAX's own reduce layout:
                        # (numpy_reconstruct_fn, fn_args, ndarray_state,
                        # aval_state) - TODO confirm against the jax
                        # version that wrote the pickle. Rebuild the plain
                        # numpy array instead of wrapping the callable.
                        array = first(*args[1])
                        array.__setstate__(args[2])
                        return array
                    return np.array(first)
                return reconstruct
            if module.startswith('jax'):
                try:
                    return super().find_class(module, name)
                except (ImportError, AttributeError):
                    # Unresolvable jax globals degrade to "return the first
                    # positional argument" (or None when called without any).
                    return lambda *a, **kw: a[0] if a else None
            return super().find_class(module, name)

    try:
        # SECURITY NOTE: every non-jax global is still resolved normally, so
        # unpickling can execute arbitrary code. Only load trusted files.
        with open(path, 'rb') as f:
            checkpoint = NumpyUnpickler(f).load()
        params = checkpoint['model']
        backbone = checkpoint['backbone']
        size = checkpoint['size']
        print(f"Loaded checkpoint (no-jax mode): backbone={backbone}, size={size}")
        if backbone != 'rdn':
            # Same sanity warning the JAX-based loader emits.
            print(f"Warning: this converter is designed for rdn, got {backbone}",
                  file=sys.stderr)
        return params
    except Exception as e:
        # Deliberate best-effort: any failure of the mock path hands over to
        # the real JAX-based loader.
        print(f"Mock unpickle failed ({e}), falling back to JAX-based loading...")
        return load_pickle_with_jax(path)
def save_safetensors(weights, output_path):
    """Write the weight dict to *output_path* in safetensors format.

    safetensors is imported inside the function, so a missing package
    surfaces as ImportError from this call rather than at module import.
    """
    from safetensors.numpy import save_file as write_safetensors

    write_safetensors(weights, output_path)
    print(f"Saved MLX weights to {output_path}")
def save_npz(weights, output_path):
    """Save weight dict as npz (fallback if safetensors not available).

    Args:
        weights: Dict mapping parameter names to numpy arrays; each entry
            is stored under its name.
        output_path: Destination path (or writable file object) handed to
            ``np.savez``.

    The confirmation message reports the file that was actually written:
    ``np.savez`` appends ``.npz`` to path arguments that lack the suffix.
    """
    np.savez(output_path, **weights)

    written = output_path
    if not hasattr(output_path, 'write'):
        # Mirror np.savez's naming rule so the message names the real file.
        written = os.fspath(output_path)
        if not written.endswith('.npz'):
            written += '.npz'
    print(f"Saved MLX weights to {written}")
def main(argv=None):
    """Command-line entry point: load a Thera RDN checkpoint and save MLX weights.

    Args:
        argv: Optional list of command-line arguments. When ``None``
            (the default) argparse reads them from ``sys.argv``.

    Steps: parse arguments, download the pickle unless ``--input`` is given,
    load it (with or without JAX), convert the parameter tree, print a
    summary, and save as safetensors, or as npz when the output path does
    not end in ``.safetensors`` or safetensors is not installed.
    """
    parser = argparse.ArgumentParser(description="Convert Thera RDN weights to MLX format")
    parser.add_argument('--model', type=str, choices=['air', 'pro'], default='air',
                        help='Model variant (default: air)')
    parser.add_argument('--input', type=str, default=None,
                        help='Path to model.pkl (downloads from HuggingFace if not provided)')
    parser.add_argument('--output', type=str, default=None,
                        help='Output path (default: weights-{model}.safetensors)')
    parser.add_argument('--no-jax', action='store_true',
                        help='Try to load pickle without JAX installed')
    args = parser.parse_args(argv)

    if args.output is None:
        args.output = f'weights-{args.model}.safetensors'

    # Download if needed
    if args.input is None:
        repo = REPO_IDS[args.model]
        print(f"Downloading model from HuggingFace ({repo})...")
        pkl_path = download_model(args.model)
    else:
        pkl_path = args.input

    # Load
    if args.no_jax:
        flax_params = load_pickle_without_jax(pkl_path)
    else:
        flax_params = load_pickle_with_jax(pkl_path)

    # Convert
    print("Converting weights...")
    mlx_weights = convert_flax_to_mlx(flax_params, size=args.model)

    # Print summary
    total_params = sum(w.size for w in mlx_weights.values())
    print(f"Total parameters: {total_params:,}")
    print(f"Weight entries: {len(mlx_weights)}")

    # Save
    output_path = args.output
    suffix = '.safetensors'
    if output_path.endswith(suffix):
        try:
            save_safetensors(mlx_weights, output_path)
        except ImportError:
            # Swap only the trailing extension; a plain str.replace would
            # also rewrite '.safetensors' occurring earlier in the path.
            output_path = output_path[:-len(suffix)] + '.npz'
            print("safetensors not installed, saving as npz instead")
            save_npz(mlx_weights, output_path)
    else:
        save_npz(mlx_weights, output_path)


if __name__ == '__main__':
    main()